def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file,
                                     pytorch_dump_path,
                                     discriminator_or_generator):
    # Initialise PyTorch model
    config = ElectraConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    if discriminator_or_generator == "discriminator":
        model = ElectraForPreTraining(config)
    elif discriminator_or_generator == "generator":
        model = ElectraForMaskedLM(config)
    else:
        raise ValueError(
            "The discriminator_or_generator argument should be either 'discriminator' or 'generator'"
        )

    # Load weights from tf checkpoint
    load_tf_weights_in_electra(
        model,
        config,
        tf_checkpoint_path,
        discriminator_or_generator=discriminator_or_generator)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
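The converter above is typically driven from the command line. A minimal argparse wrapper might look like the sketch below; the flag names simply mirror the function's parameter names and are assumptions, not necessarily the original script's exact flags.

# Hypothetical CLI wrapper for convert_tf_checkpoint_to_pytorch (flag names are assumptions).
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--config_file", required=True,
                        help="Path to the ELECTRA config JSON file.")
    parser.add_argument("--pytorch_dump_path", required=True,
                        help="Where to write the converted PyTorch weights.")
    parser.add_argument("--discriminator_or_generator", required=True,
                        choices=["discriminator", "generator"])
    cli_args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(cli_args.tf_checkpoint_path,
                                     cli_args.config_file,
                                     cli_args.pytorch_dump_path,
                                     cli_args.discriminator_or_generator)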
Example #2
    def load_electra_model(self):
        parser = argparse.ArgumentParser()
        args = parser.parse_args()
        args.output_encoded_layers = True
        args.output_attention_layers = True
        args.output_att_score = True
        args.output_att_sum = True
        self.args = args
        # Parse the config file. The teacher and the student model share the same vocab.
        # Here we use the teacher's config and the fine-tuned teacher model; they could also be
        # replaced by the student's config and the distilled student model:
        # student config:  config/chinese_bert_config_L4t.json
        # distil student model:  distil_model/gs8316.pkl
        bert_config_file_S = self.model_conf
        tuned_checkpoint_S = self.model_file
        # Load the student config; the max sequence length must not exceed the one in this config
        bert_config_S = ElectraConfig.from_json_file(bert_config_file_S)
        bert_config_S.num_labels = self.num_labels

        # Load the tokenizer
        self.predict_tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        # Load the model
        self.predict_model = ElectraSPC(bert_config_S)
        assert os.path.exists(tuned_checkpoint_S), "Model file does not exist, please check the path"
        state_dict_S = torch.load(tuned_checkpoint_S, map_location=self.device)
        self.predict_model.load_state_dict(state_dict_S)
        if self.verbose:
            print("Model loaded")
        logger.info(f"Prediction model {tuned_checkpoint_S} loaded")
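Note that the parser in load_electra_model defines no arguments, so parse_args() only yields an empty namespace to hang flags on. Building an argparse.Namespace directly is equivalent and avoids touching sys.argv, as in this sketch (attribute names taken from the method above):

# Equivalent flags object built directly, without an empty ArgumentParser (sketch).
from argparse import Namespace

args = Namespace(
    output_encoded_layers=True,
    output_attention_layers=True,
    output_att_score=True,
    output_att_sum=True,
)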
Example #3
def main():
    #parse arguments
    config.parse()
    args = config.args

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -  %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    #arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}"
        )

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    #load bert config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None

    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(
            args.train_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(
            args.predict_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

    if args.local_rank == 0:
        torch.distributed.barrier()
    #Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        #state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
        #missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,
                                                                strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    if args.do_train:
        #parameters
        if args.lr_decay is not None:
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params,
                                               lr=args.learning_rate)

            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(
                    model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            embed_params = [
                (name, value)
                for name, value in model_S.electra.named_parameters()
                if 'embedding' in name
            ]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x:len(x['params']), all_trainable_params))==len(list(model_S.parameters())),\
                (sum(map(lambda x:len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params,
                                                     lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)

        num_train_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params,
                          lr=args.learning_rate,
                          correct_bias=False)
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps)
            }
            #scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps),
                'num_training_steps': num_train_steps
            }
            #scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps = num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank,
                       len(train_examples))
        logger.warning("local_rank %d Num split examples = %d",
                       args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d",
                       args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d",
                       args.local_rank, num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")

        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # evaluate the model in a single process via ddp_predict
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S,
                          eval_examples,
                          eval_dataset,
                          step=0,
                          args=args)
        print(res)
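The lr_decay branch above implements layer-wise learning-rate decay: the classifier gets the base learning rate, the top encoder layer gets base_lr * lr_decay, each layer below one more factor of lr_decay, and the embeddings the smallest rate. The following self-contained sketch expresses the same idea with plain torch.optim.AdamW parameter groups; the attribute names (model.classifier, model.encoder.layer, model.embeddings) and the helper name are generic placeholders rather than this repository's divide_parameters API.

# Layer-wise LR decay sketch: deeper layers get geometrically smaller learning rates.
# Assumes a model exposing .classifier, .encoder.layer (a ModuleList) and .embeddings.
import torch

def layerwise_param_groups(model, base_lr, decay):
    groups = [{"params": list(model.classifier.parameters()), "lr": base_lr}]
    n_layers = len(model.encoder.layer)
    # top layer first: lr = base_lr * decay, next one down base_lr * decay**2, ...
    for i, n in enumerate(reversed(range(n_layers))):
        groups.append({"params": list(model.encoder.layer[n].parameters()),
                       "lr": base_lr * decay ** (i + 1)})
    # embeddings get the smallest learning rate
    groups.append({"params": list(model.embeddings.parameters()),
                   "lr": base_lr * decay ** (n_layers + 1)})
    return groups

# usage (hypothetical model):
# optimizer = torch.optim.AdamW(layerwise_param_groups(model, base_lr=1e-4, decay=0.8))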
Example #4
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k,v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Student configuration
    if args.model_architecture == "electra":
        # Import ElectraConfig from the transformers package and load the config
        bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True; output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers=='true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    elif args.model_architecture == "albert":
        # Import AlbertConfig from the transformers package and load the config
        bert_config_S = AlbertConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True; output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers=='true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
        assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #read data
    train_dataset = None
    eval_datasets  = None
    num_train_steps = None
    # Both electra and bert use the bert tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset)/args.train_batch_size) * args.num_train_epochs
        logger.info("Training dataset loaded")
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("Prediction dataset loaded")


    # Build the student model
    if args.model_architecture == "electra":
        # Build the model from the config. Only the student model is used; in effect this trains a single (teacher) model
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        # Build the model from the config. Only the student model is used; in effect this trains a single (teacher) model
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels,args=args)
    # Initialize the parameters of the loaded student model; the student model is used for prediction
    if args.load_model_type=='bert' and args.model_architecture not in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.embeddings')}
            missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
            print(f"missing_keys, note these missing parameters: {missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type=='all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S,map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif args.model_architecture in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S) #,output_device=n_gpu-1)

    if args.do_train:
        # params is the list of all named parameters of the model
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups (decay_group and no_decay_group): %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params,lr=args.learning_rate,
                             warmup=args.warmup_proportion,t_total=num_train_steps,schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num training steps = %d", num_train_steps)

        ########### TRAINING CONFIG ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps = args.gradient_accumulation_steps,
            ckpt_frequency = args.ckpt_frequency,
            log_dir = args.output_dir,
            output_dir = args.output_dir,
            device = args.device)

        # Initialize the trainer for supervised training (not distillation); it can train model_S into a teacher model
        distiller = BasicTrainer(train_config = train_config,
                                 model = model_S,
                                 adaptor = BertForGLUESimpleAdaptorTraining)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # Training dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size, drop_last=True)
        # Run the callback on the eval datasets
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # Start training
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S,eval_datasets,step=0,args=args)
        print(res)
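Stripped of the training scaffolding, the ELECTRA branch above amounts to: load the config from JSON, set num_labels and the output flags, build the classification model, and restore fine-tuned weights. A minimal sketch using only stock transformers classes (paths are placeholders, and ElectraForSequenceClassification stands in for the repository's custom ElectraSPC head):

# Minimal config-driven ELECTRA classifier construction (sketch; paths are placeholders).
import torch
from transformers import ElectraConfig, ElectraForSequenceClassification

config = ElectraConfig.from_json_file("electra_config.json")
config.num_labels = 3                   # number of classes for the task
config.output_hidden_states = True      # expose per-layer hidden states
model = ElectraForSequenceClassification(config)

# Fine-tuned weights can then be restored from a saved state dict:
# model.load_state_dict(torch.load("tuned_model.pkl", map_location="cpu"))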
Example #5
    logging.info('Init pre-train model...')
    if params.pre_model_type == 'NEZHA':
        bert_config = NEZHAConfig.from_json_file(
            os.path.join(params.bert_model_dir, 'bert_config.json'))
        model = BertForTokenClassification(config=bert_config, params=params)
        # NEZHA init
        torch_init_model(
            model, os.path.join(params.bert_model_dir, 'pytorch_model.bin'))
    elif params.pre_model_type == 'RoBERTa':
        bert_config = BertConfig.from_json_file(
            os.path.join(params.bert_model_dir, 'bert_config.json'))
        model = BertForTokenClassification.from_pretrained(
            config=bert_config,
            pretrained_model_name_or_path=params.bert_model_dir,
            params=params)
    elif params.pre_model_type == 'ELECTRA':
        bert_config = ElectraConfig.from_json_file(
            os.path.join(params.bert_model_dir, 'bert_config.json'))
        model = ElectraForTokenClassification.from_pretrained(
            config=bert_config,
            pretrained_model_name_or_path=params.bert_model_dir,
            params=params)
    else:
        raise ValueError(
            'Pre-train Model type must be NEZHA or ELECTRA or RoBERTa!')
    logging.info('-done')

    # Train and evaluate the model
    logging.info("Starting training for {} epoch(s)".format(args.epoch_num))
    train_and_evaluate(model, params, args.restore_file)
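The if/elif chain above selects the config and model class from params.pre_model_type. The same dispatch can be written as a lookup table, which keeps the supported types in one place; the sketch below only shows the stock transformers config classes, since NEZHAConfig comes from the project's own code and would be added to the table the same way.

# Table-driven variant of the pre_model_type dispatch (sketch).
from transformers import BertConfig, ElectraConfig

CONFIG_CLASSES = {
    'RoBERTa': BertConfig,
    'ELECTRA': ElectraConfig,
    # 'NEZHA': NEZHAConfig,   # project-specific class, added the same way
}

def load_pretrain_config(pre_model_type, config_path):
    try:
        config_cls = CONFIG_CLASSES[pre_model_type]
    except KeyError:
        raise ValueError('Pre-train Model type must be NEZHA or ELECTRA or RoBERTa!')
    return config_cls.from_json_file(config_path)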
Example #6
    # 'albert' branch reconstructed by analogy with the 'bert'/'electra' branches below
    if 'albert' in args.model:
        model_type = 'albert'
        tokenizer = AlbertTokenizer(vocab_file=args.tokenizer)
        config = AlbertConfig.from_json_file(args.config)
        model = AlbertModel.from_pretrained(pretrained_model_name_or_path=None,
                                            config=config,
                                            state_dict=torch.load(args.model))
    elif 'bert' in args.model:
        model_type = 'bert'
        tokenizer = BertTokenizer(vocab_file=args.tokenizer)
        config = BertConfig.from_json_file(args.config)
        model = BertModel.from_pretrained(pretrained_model_name_or_path=None,
                                          config=config,
                                          state_dict=torch.load(args.model))
    elif 'electra' in args.model:
        model_type = 'electra'
        tokenizer = ElectraTokenizer(vocab_file=args.tokenizer)
        config = ElectraConfig.from_json_file(args.config)
        model = ElectraModel.from_pretrained(
            pretrained_model_name_or_path=None,
            config=config,
            state_dict=torch.load(args.model))
    else:
        raise NotImplementedError("The model is currently not supported")

    def process_line(line):
        data = json.loads(line)
        tokens = data['text'].split(' ')
        labels = data['targets']
        return tokens, labels

    def retokenize(tokens_labels):
        tokens, labels = tokens_labels
Example #7
def main():
    #parse arguments
    config.parse()
    args = config.args

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -  %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    #arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}"
        )

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Import ElectraConfig from the transformers package and load the config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    # (args.output_encoded_layers == 'true') --> True; output hidden states by default
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    # num_labels: number of classes
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    # Load the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        # Returns all examples and the dataset built from them; the dataset format is [all_token_ids, all_input_mask, all_label_ids]
        print("Loading training data")
        train_examples, train_dataset = read_features(
            args.train_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)
    if args.do_predict:
        print("Loading test data")
        eval_examples, eval_dataset = read_features(
            args.predict_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

    if args.local_rank == 0:
        torch.distributed.barrier()
    #Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load the student model's parameters; the default type is bert
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        #state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
        #missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,
                                                                strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)

    if args.do_train:
        #parameters
        if args.lr_decay is not None:
            # Classifier layer parameters: weight, bias
            outputs_params = list(model_S.classifier.named_parameters())
            # Split the parameters into a decay group and a no-decay group
            outputs_params = divide_parameters(outputs_params,
                                               lr=args.learning_rate)

            electra_params = []
            # e.g. 12: the encoder has 12 layers
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(
                    model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i}: learning rate of layer {n}: {lr}")
            embed_params = [
                (name, value)
                for name, value in model_S.electra.named_parameters()
                if 'embedding' in name
            ]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embedding layer lr: {lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x:len(x['params']), all_trainable_params))==len(list(model_S.parameters())),\
                (sum(map(lambda x:len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params,
                                                     lr=args.learning_rate)
        logger.info("Number of trainable parameter groups in all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        # Build the dataloader
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Compute the number of training steps from the number of epochs
        num_train_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params,
                          lr=args.learning_rate,
                          correct_bias=False)
        if args.official_schedule == 'const':
            # Constant learning rate (with warmup)
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps)
            }
            #scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            # Linear learning-rate schedule
            scheduler_class = get_linear_schedule_with_warmup
            # Number of warmup steps: warmup_proportion of the total steps (e.g. 10%)
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps),
                'num_training_steps': num_train_steps
            }
            #scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps = num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank,
                       len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank,
                       len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank,
                       forward_batch_size)
        logger.warning("local_rank %d Num training steps = %d", args.local_rank,
                       num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,  # checkpoint save frequency
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)

        logger.info("Training configuration:")
        logger.info(f"{train_config}")
        # Initialize the trainer
        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # Initialize the callback; evaluation uses the ddp_predict function
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S,
                          eval_examples,
                          eval_dataset,
                          step=0,
                          args=args)
        print(res)
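Examples #3 and #7 hand scheduler_class and scheduler_args to the trainer instead of constructing the scheduler themselves. Inside the trainer this presumably reduces to the pattern sketched below, using the stock transformers schedule helpers; the model, optimizer and step counts are placeholders, since the trainer's internals are not shown on this page.

# Deferred scheduler construction (sketch): scheduler_class(optimizer, **scheduler_args),
# then one scheduler.step() per optimizer step.  Numbers are placeholders.
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(4, 2)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler_class = get_linear_schedule_with_warmup
scheduler_args = {'num_warmup_steps': 100, 'num_training_steps': 1000}
scheduler = scheduler_class(optimizer, **scheduler_args)

for _ in range(1000):
    optimizer.step()
    scheduler.step()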