Example #1
    def load_model(self):
        parser = argparse.ArgumentParser()
        args = parser.parse_args()
        args.output_encoded_layers = True
        args.output_attention_layers = True
        args.output_att_score = True
        args.output_att_sum = True
        self.args = args
        # Parse the config files; the teacher and student models share the same vocab
        self.vocab_file = "albert_model/vocab.txt"
        # Here we use the teacher's config and the fine-tuned teacher model; they can be swapped for the student config and the distilled student model
        # student config:  config/chinese_bert_config_L4t.json
        # distil student model:  distil_model/gs8316.pkl
        self.bert_config_file_S = "albert_model/config.json"
        self.tuned_checkpoint_S = "trained_teacher_model/test_components.pkl"
        self.max_seq_length = 70
        # batch size used for prediction
        self.predict_batch_size = 64
        # Load the student config and check that the max sequence length does not exceed the one in the config
        bert_config_S = AlbertConfig.from_json_file(self.bert_config_file_S)
        bert_config_S.num_labels = self.num_labels

        # Load the tokenizer
        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        # Load the model
        model_S = AlbertSPC(bert_config_S)
        state_dict_S = torch.load(self.tuned_checkpoint_S, map_location=self.device)
        model_S.load_state_dict(state_dict_S)
        if self.verbose:
            print("Model loaded")

        return tokenizer, model_S
Example #2
    def load_macbert_model(self):
        parser = argparse.ArgumentParser()
        args = parser.parse_args()
        args.output_encoded_layers = True
        args.output_attention_layers = True
        args.output_att_score = True
        args.output_att_sum = True
        self.args = args
        # Parse the config files; the teacher and student models share the same vocab
        self.vocab_file = "mac_bert_model/vocab.txt"
        # Here we use the teacher's config and the fine-tuned teacher model; they can be swapped for the student config and the distilled student model
        # student config:  config/chinese_bert_config_L4t.json
        # distil student model:  distil_model/gs8316.pkl
        self.bert_config_file_S = "mac_bert_model/config.json"
        self.tuned_checkpoint_S = "trained_teacher_model/macbert_teacher_max75len_5000.pkl"
        # Load the student config and check that the max sequence length does not exceed the one in the config
        bert_config_S = AlbertConfig.from_json_file(self.bert_config_file_S)

        # Load the tokenizer
        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        # Load the model
        model_S = AlbertSPC(bert_config_S)
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        model_S.load_state_dict(state_dict_S)
        if self.verbose:
            print("Model loaded")
        self.predict_tokenizer = tokenizer
        self.predict_model = model_S
        logger.info("macbert prediction model loaded")
Example #3
    def __init__(
        self,
        pretrained_model_name=None,
        config_filename=None,
        vocab_size=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        max_position_embeddings=512,
    ):
        super().__init__()

        # Check that only one of pretrained_model_name, config_filename, and
        # vocab_size was passed in
        total = 0
        if pretrained_model_name is not None:
            total += 1
        if config_filename is not None:
            total += 1
        if vocab_size is not None:
            total += 1

        if total != 1:
            raise ValueError(
                "Only one of pretrained_model_name, vocab_size, "
                + "or config_filename should be passed into the "
                + "ALBERT constructor."
            )

        # Note: the branches below re-check which argument was provided; the final else is unreachable after the check above.
        if vocab_size is not None:
            config = AlbertConfig(
                vocab_size_or_config_json_file=vocab_size,
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                num_hidden_layers=num_hidden_layers,
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                hidden_act=hidden_act,
                max_position_embeddings=max_position_embeddings,
            )
            model = AlbertModel(config)
        elif pretrained_model_name is not None:
            model = AlbertModel.from_pretrained(pretrained_model_name)
        elif config_filename is not None:
            config = AlbertConfig.from_json_file(config_filename)
            model = AlbertModel(config)
        else:
            raise ValueError(
                "Either pretrained_model_name or vocab_size must" + " be passed into the ALBERT constructor"
            )

        model.to(self._device)

        self.add_module("albert", model)
        self.config = model.config
        self._hidden_size = model.config.hidden_size
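A minimal usage sketch of the constructor above. The class name AlbertEncoder is a placeholder (the defining class is not shown in this snippet); exactly one of the three initialization arguments may be supplied per instance.

# "AlbertEncoder" is a hypothetical name for the class defining the __init__ above.
# Exactly one of pretrained_model_name, config_filename, or vocab_size may be passed.
encoder_from_hub = AlbertEncoder(pretrained_model_name="albert-base-v2")
encoder_from_config = AlbertEncoder(config_filename="albert_model/config.json")
encoder_from_scratch = AlbertEncoder(vocab_size=30000, hidden_size=768, num_hidden_layers=12)
# Passing more (or fewer) than one of these raises the ValueError defined above.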
Example #4
def launch(training_flag, test_flag):
    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    if training_flag:
        model = AlbertForTokenClassification.from_pretrained(
            'albert-base-v2', num_labels=len(tags_vals))
        ## --------- 12. Optimizer -> weight regularization is a way to reduce overfitting in a deep learning model
        """
        Recent Keras usage (2020): rates around 0.01 seem to be the best hyperparameter for weight regularization of weight layers, e.g.:
            from keras.layers import LSTM
            from keras.regularizers import l2
            model.add(LSTM(32, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), bias_regularizer=l2(0.01)))
        Note: BERT does not apply weight decay to the beta and gamma parameters.
        """
        FULL_FINETUNING = True
        if FULL_FINETUNING:
            param_optimizer = list(model.named_parameters())
            no_decay = ['bias', 'gamma', 'beta']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay_rate':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay_rate':
                0.0
            }]
        else:
            param_optimizer = list(model.classifier.named_parameters())
            optimizer_grouped_parameters = [{
                "params": [p for n, p in param_optimizer]
            }]
        optimizer = Adam(optimizer_grouped_parameters, lr=args.lr)
        launch_training(training_path=args.training_data,
                        training_epochs=args.epochs,
                        valid_path=args.validate_data,
                        training_batch_size=1,
                        model=model,
                        model_path=model_path,
                        tokenizer=tokenizer,
                        optimizer=optimizer)
    if test_flag:
        if args.save:
            model_path = args.save + 'pytorch_model.bin'
            config = AlbertConfig.from_json_file(args.save + '/config.json')
            model = AlbertForTokenClassification.from_pretrained(args.save,
                                                                 config=config)
        else:
            model = AlbertForTokenClassification.from_pretrained(
                'albert-base-v2', num_labels=len(tags_vals))
        launch_test_directory(test_path=test_flag,
                              model=model,
                              tokenizer=tokenizer)
Example #5
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file,
                                     pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = AlbertForMaskedLM(config)
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
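For reference, a hedged call sketch for the converter above; the three paths are placeholders for a TensorFlow ALBERT checkpoint, its JSON config, and the desired PyTorch output file.

# All paths below are illustrative placeholders.
convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="albert_base/model.ckpt-best",
    albert_config_file="albert_base/albert_config.json",
    pytorch_dump_path="albert_base/pytorch_model.bin",
)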
Example #6
    def load_train_model(self):
        """
        Initialize the model for training.
        :return:
        """
        parser = argparse.ArgumentParser()
        args = parser.parse_args()
        args.output_encoded_layers = True
        args.output_attention_layers = True
        args.output_att_score = True
        args.output_att_sum = True
        self.learning_rate = 2e-05
        # proportion of steps used for learning-rate warmup
        self.warmup_proportion = 0.1
        self.num_train_epochs = 1
        # learning-rate scheduler to use
        self.schedule = 'slanted_triangular'
        self.s_opt1 = 30.0
        self.s_opt2 = 0.0
        self.s_opt3 = 1.0
        self.weight_decay_rate = 0.01
        # save a model checkpoint every this many training epochs
        self.ckpt_frequency = 1
        # directory where the model and logs are saved
        self.output_dir = "output_root_dir/train_api"
        # number of gradient accumulation steps
        self.gradient_accumulation_steps = 1
        self.args = args
        # Parse the config files; the teacher and student models share the same vocab
        self.vocab_file = "albert_model/vocab.txt"
        self.bert_config_file_S = "albert_model/config.json"
        self.tuned_checkpoint_S = "albert_model/pytorch_model.bin"
        # Load the student config and check that the max sequence length does not exceed the one in the config
        bert_config_S = AlbertConfig.from_json_file(self.bert_config_file_S)

        # Load the tokenizer
        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        # Load the model
        model_S = AlbertSPC(bert_config_S,
                            num_labels=self.num_labels,
                            args=self.args)
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        state_weight = {
            k[5:]: v
            for k, v in state_dict_S.items() if k.startswith('bert.')
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                       strict=False)
        # verify that no parameters are missing
        assert len(missing_keys) == 0
        self.train_tokenizer = tokenizer
        self.train_model = model_S
        logger.info(f"Training model {self.tuned_checkpoint_S} loaded")
Example #7
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
Example #8
def albert_convert_tf_checkpoint_to_pytorch(tf_checkpoint_path,
                                            albert_config_file,
                                            pytorch_dump_path):
    from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = AlbertForMaskedLM(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
Example #9
    def load_predict_model(
            self,
            type,
            model_file="trained_teacher_model/components_albert.pkl"):
        """
        :param type: which kind of model to load, e.g. the component model or another category
        :type type: any one of "component", "effect", "fragrance", "pack", "skin", "promotion", "service", "price"; each value loads a different model
        :param model_file:
        :type model_file:
        :return:
        :rtype:
        """
        parser = argparse.ArgumentParser()
        args = parser.parse_args()
        args.output_encoded_layers = True
        args.output_attention_layers = True
        args.output_att_score = True
        args.output_att_sum = True
        self.args = args
        # Parse the config files; the teacher and student models share the same vocab
        self.vocab_file = "albert_model/vocab.txt"
        # Here we use the teacher's config and the fine-tuned teacher model; they can be swapped for the student config and the distilled student model
        # student config:  config/chinese_bert_config_L4t.json
        # distil student model:  distil_model/gs8316.pkl
        self.bert_config_file_S = "albert_model/config.json"
        self.tuned_checkpoint_S = model_file
        # Load the student config and check that the max sequence length does not exceed the one in the config
        bert_config_S = AlbertConfig.from_json_file(self.bert_config_file_S)
        bert_config_S.num_labels = self.num_labels

        # Load the tokenizer
        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        # Load the model
        model_S = AlbertSPC(bert_config_S)
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        model_S.load_state_dict(state_dict_S)
        if self.verbose:
            print("Model loaded")
        self.predict_tokenizer[type] = tokenizer
        self.predict_model[type] = model_S
        logger.info(f"Prediction model {model_file} loaded")
Example #10
def main():
    #
    blockPrint()

    # setting device
    device = torch.device('cuda')

    #
    FullData = MR_Data.load_data('dataset/test.tsv', is_train_data=False)
    FullDataset = makeTorchDataSet(FullData, is_train_data=False)
    TestDataLoader = makeTorchDataLoader(FullDataset, batch_size=16)
    model_config = AlbertConfig.from_json_file(
        'model/albert-large-config.json')
    trained_model_file = '12-11-2019_09-17-05_ALSS_e5_a69.24226892192033'
    model = AlbertForSequenceClassification.from_pretrained(
        'train_models/' + trained_model_file + '/pytorch_model.bin',
        config=model_config)

    model.to(device)
    model.eval()

    f = open('submission.csv', 'w', encoding='utf-8')
    f.write('PhraseId,Sentiment\n')
    log("please wait, predicting ....")
    for batch_index, batch_dict in enumerate(TestDataLoader):
        batch_dict = tuple(t.to(device) for t in batch_dict)
        input_ids, phrase_ids = batch_dict
        outputs = model(input_ids)

        outputs = outputs[0].cpu()
        outputs = outputs.detach().numpy()
        # log(outputs)

        for i in range(len(outputs)):
            p_id = phrase_ids[i].item()
            s_level = np.argmax(outputs[i])
            # log("phrase_id",p_id,"segment_level",s_level)
            f.write(str(p_id) + ',' + str(s_level) + '\n')

    f.close()
Example #11
import argparse
import os
import torch
import json
from tqdm import tqdm
from transformers import (AlbertConfig, AlbertForMaskedLM, AlbertTokenizer,
                          BertConfig, BertForMaskedLM, BertTokenizer)

parser = argparse.ArgumentParser()
parser.add_argument("-t", "--type_of_model", default = 'albert', help = "pretrained LM type")
parser.add_argument("-p", "--path_to_pytorch_models", help = "path to pytorch_model")
parser.add_argument("--config_and_vocab", help = "path to config.json and vocab.model")
parser.add_argument("-s", "--step", type = str, help = "pretrained step")
parser.add_argument("-d", "--data", help = "path where you put your processed ontonotes data")
parser.add_argument("-o", "--output", help = "output file")
args = parser.parse_args()
print("Reconstruction. step = ", args.step)
if args.type_of_model == 'albert':
  tokenizer = AlbertTokenizer(os.path.join(args.config_and_vocab, args.type_of_model, 'vocab.model'))
  config = AlbertConfig.from_json_file(os.path.join(args.config_and_vocab, args.type_of_model, 'config.json'))
  config.output_hidden_states = True
  model = AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path = None,
    config = config,
    state_dict = torch.load(os.path.join(
      args.path_to_pytorch_models, args.type_of_model, 'pytorch_model_' + args.step + '.bin')))
elif args.type_of_model == 'bert':
  tokenizer = BertTokenizer(os.path.join(args.config_and_vocab, args.type_of_model, 'vocab.model'))
  config = BertConfig.from_json_file(os.path.join(args.config_and_vocab, args.type_of_model, 'config.json'))
  config.output_hidden_states = True
  model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path = None,
    config = config,
    state_dict = torch.load(os.path.join(
      args.path_to_pytorch_models, args.type_of_model, 'pytorch_model_' + args.step + '.bin')))
else:
  raise NotImplementedError("The given model type %s is not supported" % args.type_of_model)
Example #12
def main():
    # setting device
    device = torch.device('cuda')

    #
    FullData = MR_Data.load_data('dataset/train.tsv')
    FullDataset = makeTorchDataSet(FullData)
    TrainDataset, TestDataset = splitDataset(FullDataset, 0.9)
    TrainDataLoader = makeTorchDataLoader(TrainDataset, batch_size=16)
    TestDataLoader = makeTorchDataLoader(TestDataset, batch_size=8)
    model_config = AlbertConfig.from_json_file(
        'model/albert-large-config.json')
    model = AlbertForSequenceClassification.from_pretrained(
        'model/albert-large-pytorch_model.bin', config=model_config)
    model.to(device)

    #
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)

    model.zero_grad()

    try:
        for epoch in range(15):
            # train
            running_loss_val = 0.0
            running_acc = 0.0
            for batch_index, batch_dict in enumerate(TrainDataLoader):
                model.train()
                batch_dict = tuple(t.to(device) for t in batch_dict)
                outputs = model(batch_dict[0], labels=batch_dict[1])
                loss, logits = outputs[:2]
                loss.sum().backward()
                optimizer.step()
                model.zero_grad()

                # compute the loss
                loss_t = loss.item()
                running_loss_val += (loss_t -
                                     running_loss_val) / (batch_index + 1)

                # compute the accuracy
                acc_t = computeAccuracy(logits, batch_dict[1])
                running_acc += (acc_t - running_acc) / (batch_index + 1)

                # log
                if (batch_index % 50 == 0):
                    log(">> TRAIN << epoch:%2d batch:%4d loss:%2.4f acc:%3.4f"
                        % (epoch + 1, batch_index + 1, running_loss_val,
                           running_acc))

            # test
            running_loss_val = 0.0
            running_acc = 0.0
            for batch_index, batch_dict in enumerate(TestDataLoader):
                model.eval()
                batch_dict = tuple(t.to(device) for t in batch_dict)
                outputs = model(batch_dict[0], labels=batch_dict[1])
                loss, logits = outputs[:2]

                # compute the loss
                loss_t = loss.item()
                running_loss_val += (loss_t -
                                     running_loss_val) / (batch_index + 1)

                # compute the accuracy
                acc_t = computeAccuracy(logits, batch_dict[1])
                running_acc += (acc_t - running_acc) / (batch_index + 1)

                # log
                if (batch_index % 50 == 0):
                    log(">> TEST << epoch:%2d batch:%4d loss:%2.4f acc:%3.4f" %
                        (epoch + 1, batch_index + 1, running_loss_val,
                         running_acc))

            # save model
            saveModel(model,
                      'ALSS_e%s_a%s' % (str(epoch + 1), str(running_acc)))

    except KeyboardInterrupt:
        saveModel(
            model,
            'Interrupt_ALSS_e%s_a%s' % (str(epoch + 1), str(running_acc)))

    except Exception as e:
        print(e)
Example #13
def load_eval_model(args):
    vocab_path = f'./{args["model_dir"]}/{args["dataset"]}/{args["experiment_name"]}/vocab.txt'
    config_path = f'./{args["model_dir"]}/{args["dataset"]}/{args["experiment_name"]}/config.json'
    model_path = f'./{args["model_dir"]}/{args["dataset"]}/{args["experiment_name"]}/best_model_0.th'

    # Load for word2vec and fasttext
    if 'word2vec' in args['model_type'] or 'fasttext' in args['model_type']:
        emb_path = args['embedding_path'][args['model_type']]
        model, tokenizer = load_word_embedding_model(
            args['model_type'],
            args['task'],
            vocab_path,
            args['word_tokenizer_class'],
            emb_path,
            args['num_labels'],
            lower=args['lower'])
        return model, tokenizer

    # Load config & tokenizer
    if 'albert' in args['model_type']:
        config = AlbertConfig.from_json_file(config_path)
        tokenizer = BertTokenizer(vocab_path)
    elif 'babert' in args['model_type']:
        config = BertConfig.from_json_file(config_path)
        tokenizer = BertTokenizer(vocab_path)
    elif 'scratch' in args['model_type']:
        config = BertConfig.from_pretrained('bert-base-uncased')
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    elif 'bert-base-multilingual' in args['model_type']:
        config = BertConfig.from_pretrained(args['model_type'])
        tokenizer = BertTokenizer.from_pretrained(args['model_type'])
    elif 'xlm-mlm-100-1280' in args['model_type']:
        config = XLMConfig.from_pretrained(args['model_type'])
        tokenizer = XLMTokenizer.from_pretrained(args['model_type'])
    elif 'xlm-roberta' in args['model_type']:
        config = XLMRobertaConfig.from_pretrained(args['model_type'])
        tokenizer = XLMRobertaTokenizer.from_pretrained(args['model_type'])
    else:
        raise ValueError('Invalid `model_type` argument values')

    # Get model class
    base_cls, pred_cls = get_model_class(args['model_type'], args['task'])

    # Adjust config
    if type(args['num_labels']) == list:
        config.num_labels = max(args['num_labels'])
        config.num_labels_list = args['num_labels']
    else:
        config.num_labels = args['num_labels']

    # Instantiate model
    model = pred_cls(config=config)
    base_model = base_cls.from_pretrained(model_path,
                                          from_tf=False,
                                          config=config)

    # Plug pretrained base model to classification model
    if 'bert' in model.__dir__():
        model.bert = base_model
    elif 'albert' in model.__dir__():
        model.albert = base_model
    elif 'roberta' in model.__dir__():
        model.roberta = base_model
    elif 'transformer' in model.__dir__():
        model.transformer = base_model
    else:
        raise ValueError(
            'Model attribute not found, is there any change in the `transformers` library?'
        )

    return model, tokenizer
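A sketch of the args dictionary that load_eval_model expects, with keys inferred from the accesses above; every value is a placeholder rather than a path or name taken from the original project.

# Keys inferred from load_eval_model; all values are placeholders.
args = {
    'model_dir': 'save',
    'dataset': 'my_dataset',
    'experiment_name': 'albert_run1',
    'model_type': 'albert',        # or 'babert', 'scratch', 'xlm-roberta', ...
    'task': 'sequence_classification',
    'num_labels': 3,
    'word_tokenizer_class': None,  # only read for word2vec / fasttext models
    'embedding_path': {},          # only read for word2vec / fasttext models
    'lower': True,
}
model, tokenizer = load_eval_model(args)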
Example #14
def load_model(args):
    if 'albert-large-wwmlm-512' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-large-wwmlm-512/albert_large_model_bpe_wwmlm_512_vocab_uncased_30000.txt"
        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(
            "../embeddings/albert-large-wwmlm-512/albert_large_model_bpe_wwmlm_512_albert_large_config.json"
        )
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-large-wwmlm-512/albert_large_model_bpe_wwmlm_512_pytorch_albert_large_512_629k.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model
    elif 'albert-base-wwmlm-512' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-base-wwmlm-512/albert_base_model_bpe_wwmlm_512_vocab_uncased_30000.txt"
        config_path = "../embeddings/albert-base-wwmlm-512/albert_base_model_bpe_wwmlm_512_albert_base_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-base-wwmlm-512/albert_base_model_bpe_wwmlm_512_pytorch_model_albert_base_162k.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model

    elif 'albert-large-wwmlm-128' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-large-wwmlm-128/albert_large_model_bpe_wwmlm_128_vocab_uncased_30000.txt"
        config_path = "../embeddings/albert-large-wwmlm-128/albert_large_model_bpe_wwmlm_128_albert_large_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-large-wwmlm-128/albert_large_model_bpe_wwmlm_128_pytorch_albert_large_128_500k.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model

    elif 'babert-bpe-mlm-large-512' == args['model_checkpoint']:
        # babert_bpe
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-bpe-mlm-large-512/babert_model_bpe_mlm_uncased_large_512_dup10-5_vocab_uncased_30522.txt"
        config_path = "../embeddings/babert-bpe-mlm-large-512/babert_model_bpe_mlm_uncased_large_512_dup10-5_bert_large_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-bpe-mlm-large-512/babert_model_bpe_mlm_uncased_large_512_dup10-5_pytorch_babert_uncased_large_512_dup10-5_1120k.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'albert-base-uncased-112500' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-base-uncased-112500/vocab.txt"
        config_path = "../embeddings/albert-base-uncased-112500/bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-base-uncased-112500/albert_base_uncased_112500.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model

    elif 'albert-base-uncased-96000' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-base-uncased-96000/vocab.txt"
        config_path = "../embeddings/albert-base-uncased-96000/bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-base-uncased-96000/albert_base_uncased_96000.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model

    elif 'albert-base-uncased-191k' == args['model_checkpoint']:
        vocab_path = "../embeddings/albert-base-uncased-191k/pytorch_models_albert_base_uncased_191500_vocab_uncased_30000.txt"
        config_path = "../embeddings/albert-base-uncased-191k/pytorch_models_albert_base_uncased_191500_albert_base_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = AlbertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = AlbertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = AlbertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = AlbertForMultiLabelClassification(config)

        # Plug pretrained bert model
        albert_model = AlbertModel.from_pretrained(
            "../embeddings/albert-base-uncased-191k/pytorch_models_albert_base_uncased_191500_pytorch_model_albert_base_191k.bin",
            from_tf=False,
            config=config)
        model.albert = albert_model

    elif 'babert-opensubtitle' == args['model_checkpoint']:
        # babert-opensubtitle
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-opensubtitle/vocab.txt"
        config_path = "../embeddings/babert-opensubtitle/bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-opensubtitle/model.ckpt-1000000.index",
            from_tf=True,
            config=config)
        model.bert = bert_model.bert

    elif 'babert-bpe-mlm-large-uncased-1100k' == args['model_checkpoint']:
        # babert_bpe
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-bpe-mlm-large-uncased-1100k/pytorch_models_babert_uncased_large_1100k_vocab_uncased_30522.txt"
        config_path = "../embeddings/babert-bpe-mlm-large-uncased-1100k/pytorch_models_babert_uncased_large_1100k_bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-bpe-mlm-large-uncased-1100k/pytorch_models_babert_uncased_large_1100k_pytorch_model_babert_large_1100k.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'babert-bpe-mlm-large-uncased-1m' == args['model_checkpoint']:
        # babert_bpe
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-bpe-mlm-large-uncased-1m/pytorch_models_babert_uncased_large_1mil_vocab_uncased_30522.txt"
        config_path = "../embeddings/babert-bpe-mlm-large-uncased-1m/pytorch_models_babert_uncased_large_1mil_bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-bpe-mlm-large-uncased-1m/pytorch_models_babert_uncased_large_1mil_pytorch_model_babert_large_1mil.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'babert-base-512' == args['model_checkpoint']:
        # babert_bpe
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-base-512/pytorch_models_babert_base_512_vocab_uncased_30522.txt"
        config_path = "../embeddings/babert-base-512/pytorch_models_babert_base_512_bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-base-512/pytorch_models_babert_base_512_pytorch_model_babert_base_uncased_512.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'babert-bpe-mlm-large-uncased' == args['model_checkpoint']:
        # babert_bpe
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-bpe-mlm-large-uncased/pytorch_models_babert_uncased_large_vocab_uncased_30522.txt"
        config_path = "../embeddings/babert-bpe-mlm-large-uncased/pytorch_models_babert_uncased_large_bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-bpe-mlm-large-uncased/pytorch_models_babert_uncased_large_pytorch_model_babert_large_778500.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'babert-bpe-mlm-uncased-128-dup10-5' == args['model_checkpoint']:
        # babert_bpe_wwmlm
        # Prepare config & tokenizer
        vocab_path = "../embeddings/babert-bpe-mlm-uncased-128-dup10-5/vocab.txt"
        config_path = "../embeddings/babert-bpe-mlm-uncased-128-dup10-5/bert_config.json"

        tokenizer = BertTokenizer(vocab_path)
        config = BertConfig.from_json_file(config_path)
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)

        # Plug pretrained bert model
        bert_model = BertForPreTraining.from_pretrained(
            "../embeddings/babert-bpe-mlm-uncased-128-dup10-5/pytorch_model.bin",
            config=config)
        model.bert = bert_model.bert

    elif 'bert-base-multilingual' in args['model_checkpoint']:
        # bert-base-multilingual-uncased or bert-base-multilingual-cased
        # Prepare config & tokenizer
        vocab_path, config_path = None, None
        tokenizer = BertTokenizer.from_pretrained(args['model_checkpoint'])
        config = BertConfig.from_pretrained(args['model_checkpoint'])
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification.from_pretrained(
                args['model_checkpoint'], config=config)

    elif 'xlm-mlm' in args['model_checkpoint']:
        # xlm-mlm-100-1280
        # Prepare config & tokenizer
        vocab_path, config_path = None, None
        tokenizer = XLMTokenizer.from_pretrained(args['model_checkpoint'])
        config = XLMConfig.from_pretrained(args['model_checkpoint'])
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = XLMForSequenceClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'token_classification' == args['task']:
            model = XLMForWordClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'multi_label_classification' == args['task']:
            model = XLMForMultiLabelClassification.from_pretrained(
                args['model_checkpoint'], config=config)
    elif 'xlm-roberta' in args['model_checkpoint']:
        # xlm-roberta-base or xlm-roberta-large
        # Prepare config & tokenizer
        vocab_path, config_path = None, None
        tokenizer = XLMRobertaTokenizer.from_pretrained(
            args['model_checkpoint'])
        config = XLMRobertaConfig.from_pretrained(args['model_checkpoint'])
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = XLMRobertaForSequenceClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'token_classification' == args['task']:
            model = XLMRobertaForWordClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'multi_label_classification' == args['task']:
            model = XLMRobertaForMultiLabelClassification.from_pretrained(
                args['model_checkpoint'], config=config)
    elif 'word2vec' in args['model_checkpoint'] or 'fasttext' in args[
            'model_checkpoint']:
        # Prepare config & tokenizer
        vocab_path = args['vocab_path']
        config_path = None

        word_tokenizer = args['word_tokenizer_class']()
        emb_path = args['embedding_path'][args['model_checkpoint']]

        _, vocab_map = load_vocab(vocab_path)
        tokenizer = SimpleTokenizer(vocab_map,
                                    word_tokenizer,
                                    lower=args["lower"])
        vocab_list = list(tokenizer.vocab.keys())

        config = BertConfig.from_pretrained('bert-base-uncased')
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']
        config.num_hidden_layers = args["num_layers"]

        if args['model_checkpoint'] == 'word2vec-twitter':
            embeddings = gen_embeddings(vocab_list, emb_path)
            config.hidden_size = 400
            config.num_attention_heads = 8

        if args['model_checkpoint'] == 'fasttext-cc-id' or args[
                'model_checkpoint'] == 'fasttext-cc-id-300-no-oov-uncased' or args[
                    'model_checkpoint'] == 'fasttext-4B-id-300-no-oov-uncased':
            embeddings = gen_embeddings(vocab_list, emb_path, emb_dim=300)
            config.hidden_size = 300
            config.num_attention_heads = 10

        config.vocab_size = len(embeddings)

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config)
            model.bert.embeddings.word_embeddings.weight.data.copy_(
                torch.FloatTensor(embeddings))
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config)
            model.bert.embeddings.word_embeddings.weight.data.copy_(
                torch.FloatTensor(embeddings))
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config)
            model.bert.embeddings.word_embeddings.weight.data.copy_(
                torch.FloatTensor(embeddings))

    elif 'scratch' in args['model_checkpoint']:
        vocab_path, config_path = None, None

        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        config = BertConfig.from_pretrained("bert-base-uncased")
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']
        config.num_hidden_layers = args["num_layers"]
        config.hidden_size = 300
        config.num_attention_heads = 10

        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification(config=config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification(config=config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification(config=config)
    elif 'indobenchmark' in args['model_checkpoint']:
        # indobenchmark models
        # Prepare config & tokenizer
        vocab_path, config_path = None, None
        tokenizer = BertTokenizer.from_pretrained(args['model_checkpoint'])
        config = BertConfig.from_pretrained(args['model_checkpoint'])
        if type(args['num_labels']) == list:
            config.num_labels = max(args['num_labels'])
            config.num_labels_list = args['num_labels']
        else:
            config.num_labels = args['num_labels']

        # Instantiate model
        if 'sequence_classification' == args['task']:
            model = BertForSequenceClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'token_classification' == args['task']:
            model = BertForWordClassification.from_pretrained(
                args['model_checkpoint'], config=config)
        elif 'multi_label_classification' == args['task']:
            model = BertForMultiLabelClassification.from_pretrained(
                args['model_checkpoint'], config=config)

    return model, tokenizer, vocab_path, config_path
Example #15
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k,v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Student configuration
    if args.model_architecture == "electra":
        # Import ElectraConfig from the transformers package and load the configuration
        bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True; output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers=='true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    elif args.model_architecture == "albert":
        # Import AlbertConfig from the transformers package and load the configuration
        bert_config_S = AlbertConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True; output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers=='true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
        assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #read data
    train_dataset = None
    eval_datasets  = None
    num_train_steps = None
    # Both electra and bert use the bert tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset)/args.train_batch_size) * args.num_train_epochs
        logger.info("Training dataset loaded")
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("Prediction datasets loaded")


    # Student configuration
    if args.model_architecture == "electra":
        # Load the model config; only the student model is used here, which effectively trains a teacher model (a single model)
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        # Load the model config; only the student model is used here, which effectively trains a teacher model (a single model)
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels,args=args)
    # Initialize the parameters of the loaded student model; the student model is used for prediction
    if args.load_model_type=='bert' and args.model_architecture not in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.embeddings')}
            missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
            print(f"missing_keys, note the missing parameters: {missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type=='all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S,map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif args.model_architecture in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S) #,output_device=n_gpu-1)

    if args.do_train:
        # parameters: params is a list of all named parameters of the model
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups (decay_group and no_decay_group): %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params,lr=args.learning_rate,
                             warmup=args.warmup_proportion,t_total=num_train_steps,schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Start training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num training steps = %d", num_train_steps)

        ########### Training configuration ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps = args.gradient_accumulation_steps,
            ckpt_frequency = args.ckpt_frequency,
            log_dir = args.output_dir,
            output_dir = args.output_dir,
            device = args.device)

        # Initialize the trainer for supervised training rather than distillation; it can train model_S into a teacher model
        distiller = BasicTrainer(train_config = train_config,
                                 model = model_S,
                                 adaptor = BertForGLUESimpleAdaptorTraining)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # training dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size,drop_last=True)
        # callback function run on the eval datasets
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # start training
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                              num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S,eval_datasets,step=0,args=args)
        print(res)
    args = parser.parse_args()

    label_dict = {}
    with open(args.label, 'r') as f:
        line_id = 0
        while True:
            line = f.readline()
            if line == '': break
            line = line.rstrip()
            label_dict[line] = line_id
            line_id += 1

    if 'albert' in args.model:
        model_type = 'albert'
        tokenizer = AlbertTokenizer(vocab_file=args.tokenizer)
        config = AlbertConfig.from_json_file(args.config)
        model = AlbertModel.from_pretrained(pretrained_model_name_or_path=None,
                                            config=config,
                                            state_dict=torch.load(args.model))
    elif 'bert' in args.model:
        model_type = 'bert'
        tokenizer = BertTokenizer(vocab_file=args.tokenizer)
        config = BertConfig.from_json_file(args.config)
        model = BertModel.from_pretrained(pretrained_model_name_or_path=None,
                                          config=config,
                                          state_dict=torch.load(args.model))
    elif 'electra' in args.model:
        model_type = 'electra'
        tokenizer = ElectraTokenizer(vocab_file=args.tokenizer)
        config = ElectraConfig.from_json_file(args.config)
        model = ElectraModel.from_pretrained(