Пример #1
0
    def do_train(self, data, truncated=False):
        """
        Train ``self.train_model`` on ``data`` using a 9:1 train/eval split.

        The split is positional (no shuffling): the first ``int(len(data) * 0.9)``
        items form the training set and the remaining items the evaluation set.
        The evaluation set is evaluated through ``self.do_predict``, which is
        passed to the trainer as its callback.

        :param data: the input samples. If ``truncated`` is True they must be
            ``[(content, aspect, start_idx, end_idx, label), ...]``.
        :param truncated: whether to truncate the samples first via
            ``self.do_truncate_data``. Truncation follows
            ``self.left_max_seq_len`` / ``self.right_max_seq_len``.
        :return: the string ``"Done"`` once training has finished.
        :raises ValueError: if ``train_batch_size / gradient_accumulation_steps``
            rounds down to 0, or if the training split would be empty
            (fewer than 2 samples).
        """
        # Per-forward-pass batch size. Validate it before any tokenisation or
        # optimizer construction: 0 would otherwise only surface as an opaque
        # DataLoader error after all the setup work has been done.
        forward_batch_size = int(self.train_batch_size / self.gradient_accumulation_steps)
        if forward_batch_size < 1:
            raise ValueError(
                f"train_batch_size ({self.train_batch_size}) must be at least "
                f"gradient_accumulation_steps ({self.gradient_accumulation_steps})")
        if truncated:
            # The truncation locations are not needed for training.
            data, _ = self.do_truncate_data(data)
        train_data_len = int(len(data) * 0.9)
        if train_data_len == 0:
            # int(n * 0.9) == 0 only for n < 2. An empty training set would
            # otherwise surface later as a much less clear error.
            raise ValueError(
                f"need at least 2 samples for a 9:1 train/eval split, got {len(data)}")
        train_data = data[:train_data_len]
        eval_data = data[train_data_len:]
        train_dataset = load_examples(train_data, self.max_seq_length, self.train_tokenizer, self.label_list, self.reverse_truncate)
        eval_dataset = load_examples(eval_data, self.max_seq_length, self.train_tokenizer, self.label_list, self.reverse_truncate)
        logger.info("训练数据集已加载,开始训练")
        # Total number of optimizer steps (one per train_batch_size samples).
        # It is used as t_total for the optimizer's learning-rate schedule.
        num_train_steps = int(len(train_dataset) / self.train_batch_size) * self.num_train_epochs
        # Build parameter groups for the optimizer.
        params = list(self.train_model.named_parameters())
        all_trainable_params = divide_parameters(params, lr=self.learning_rate, weight_decay_rate=self.weight_decay_rate)
        # Optimizer setup.
        optimizer = BERTAdam(all_trainable_params, lr=self.learning_rate,
                             warmup=self.warmup_proportion, t_total=num_train_steps, schedule=self.schedule,
                             s_opt1=self.s_opt1, s_opt2=self.s_opt2, s_opt3=self.s_opt3)

        logger.info("***** 开始训练 *****")
        logger.info("  训练样本数是 = %d", len(train_dataset))
        logger.info("  评估样本数是 = %d", len(eval_dataset))
        logger.info("  前向 batch size = %d", forward_batch_size)
        logger.info("  训练的steps = %d", num_train_steps)
        # Training configuration.
        train_config = TrainingConfig(
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            ckpt_frequency=self.ckpt_frequency,
            log_dir=self.output_dir,
            output_dir=self.output_dir,
            device=self.device)
        # Plain supervised training (not distillation): this can train the
        # model into a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=self.train_model,
                                 adaptor=BertForGLUESimpleAdaptorTraining)

        train_sampler = RandomSampler(train_dataset)
        # drop_last=True: a trailing partial batch is skipped each epoch.
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=forward_batch_size, drop_last=True)
        # Callback that evaluates the model on the held-out split.
        callback_func = partial(self.do_predict, eval_dataset=eval_dataset)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=self.num_train_epochs, callback=callback_func)
        logger.info("训练完成")
        return "Done"
Пример #2
0
def main():
    """Fine-tune a single ``BertForGLUESimple`` model on a GLUE-style task.

    Configuration comes from ``config.args`` (populated by ``config.parse()``).

    Steps:
    - Seed torch / numpy / ``random``.
    - Load the datasets enabled by ``args.do_train`` / ``args.do_predict``.
      An auxiliary task's training set is concatenated when
      ``args.aux_task_name`` is set.
    - Initialise the model from one of:
      a BERT checkpoint (``load_model_type == 'bert'``, optionally embeddings only);
      a full checkpoint (``'all'``);
      random weights otherwise.
    - Then either:
      train it with ``BasicTrainer``, passing ``predict`` as the evaluation
      callback; or, when only ``do_predict`` is set, run one ``predict`` pass
      and print the result.

    Distributed training (``local_rank != -1``) raises ``NotImplementedError``.
    """
    # parse arguments
    config.parse()
    args = config.args
    # Log every parsed argument so runs can be reproduced from the log.
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # arguments check; `device` is where the model is moved, `n_gpu > 1`
    # enables DataParallel below
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size: train_batch_size is reached by accumulating
    # gradients over gradient_accumulation_steps forward passes.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # load bert config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    # Inputs cannot be longer than the position-embedding table.
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. for MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps.
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args,
                                                        args.aux_task_name,
                                                        tokenizer,
                                                        evaluate=False,
                                                        is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        # Full train_batch_size batches per epoch times epochs; logged as the
        # number of backward steps and used as the optimizer's t_total.
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on two eval sets, "mnli" and "mnli-mm"
        # (presumably matched / mismatched); other tasks on one.
        eval_task_names = ("mnli",
                           "mnli-mm") if args.task_name == "mnli" else (
                               args.task_name, )
        for eval_task in eval_task_names:
            eval_datasets.append(
                load_and_cache_examples(args,
                                        eval_task,
                                        tokenizer,
                                        evaluate=True))
    logger.info("数据集已加载")

    # Build the model. Only a single (student-class) model is used: this
    # effectively fine-tunes one model on the task data, e.g. to obtain a
    # teacher for later distillation.
    model_S = BertForGLUESimple(bert_config_S,
                                num_labels=num_labels,
                                args=args)
    # Initialise the model's weights.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            # Keep only the embedding weights; k[5:] strips the leading
            # 'bert.' so the keys match model_S.bert's own state dict.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items()
                if k.startswith('bert.embeddings')
            }
            # strict=False: the encoder weights are intentionally missing.
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            # Load every 'bert.*' weight (prefix stripped); extra keys are
            # tolerated, but no model_S.bert weight may be left missing.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items() if k.startswith('bert.')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        # A complete checkpoint of this model class, loaded strictly.
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    # Distributed training is not supported; several local GPUs use DataParallel.
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)

    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### Training (plain supervised, no distillation) ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # Plain supervised training rather than distillation; it can be used
        # to train a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # drop_last=True: a trailing partial batch is skipped each epoch.
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Evaluation callback; eval_datasets is None unless do_predict is set.
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    # Prediction-only mode: evaluate the loaded model once.
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
Пример #3
0
def main():
    """Distill a fine-tuned BERT teacher into a BERT student on a GLUE task.

    Configuration comes from ``config.args`` (populated by ``config.parse()``).

    Steps:
    - Seed torch / numpy / ``random``.
    - Load the datasets enabled by ``args.do_train`` / ``args.do_predict``.
    - Build a teacher and a student ``BertForGLUESimple`` and load their weights.
    - Then either:
      train the student with ``GeneralDistiller``, optionally matching
      intermediate layers selected by ``args.matches``, with ``predict`` as the
      evaluation callback; or, when only ``do_predict`` is set, evaluate the
      student once with ``predict`` and print the result.

    Raises:
        ValueError: if ``args.tuned_checkpoint_T`` is not set, unless running in
            prediction-only mode. Distilling from a randomly initialised teacher
            would silently produce a useless student.
        NotImplementedError: for distributed training (``local_rank != -1``).
    """
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; gradients are accumulated over
    # gradient_accumulation_steps passes to reach train_batch_size.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the BERT configs; both models must accept max_seq_length positions.
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    # Load the datasets.
    # NOTE(review): `examples` is rebound by every load_and_cache_examples call
    # below. With do_predict set it ends up holding only the last eval task's
    # examples ("mnli-mm" for MNLI); otherwise it holds the train/aux examples.
    # It is passed to predict() later; confirm that is what predict() expects.
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args,
                                                          args.task_name,
                                                          tokenizer,
                                                          evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(
                args,
                args.aux_task_name,
                tokenizer,
                evaluate=False,
                is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        # Full train_batch_size batches per epoch times epochs; used as t_total.
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli",
                           "mnli-mm") if args.task_name == "mnli" else (
                               args.task_name, )
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args,
                                                             eval_task,
                                                             tokenizer,
                                                             evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("数据集加载成功")

    # Build the teacher and student models.
    model_T = BertForGLUESimple(bert_config_T,
                                num_labels=num_labels,
                                args=args)
    model_S = BertForGLUESimple(bert_config_S,
                                num_labels=num_labels,
                                args=args)
    # Load the teacher's fine-tuned weights.
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    elif args.do_train or not args.do_predict:
        # The previous check only required do_predict. With do_train also set,
        # distillation ran against a randomly initialised teacher.
        # Prediction only uses the student, so that is the sole mode allowed
        # without a teacher checkpoint.
        raise ValueError(
            "tuned_checkpoint_T is required unless running in "
            "prediction-only mode (do_predict without do_train)")
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            # Only the embedding weights; k[5:] strips the 'bert.' prefix.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items()
                if k.startswith('bert.embeddings')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items() if k.startswith('bert.')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Student模型没有可加载参数,随机初始化参数 randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    # Distributed training is not supported; several local GPUs use DataParallel.
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  #,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)

    if args.do_train:
        # Only the student's parameters are optimised.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer configuration
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # `matches` holds predefined intermediate-layer match configurations,
        # looked up by the names given in args.matches.
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"中间层match信息: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        logger.info(f"训练配置: {train_config}")
        logger.info(f"蒸馏配置: {distill_config}")
        adaptor_T = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        # General distiller that supports intermediate-state matching.
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T,
                                     model_S=model_S,
                                     adaptor_T=adaptor_T,
                                     adaptor_S=adaptor_S)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args,
                                examples=examples)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    # Prediction-only mode: evaluate the student once.
    if not args.do_train and args.do_predict:
        res = predict(model_S,
                      eval_datasets,
                      step=0,
                      args=args,
                      examples=examples,
                      label_list=label_list)
        print(res)
Пример #4
0
def main():
    """Fine-tune a single ``BertForQASimple`` model for span-extraction QA.

    Configuration comes from ``config.args``.

    Steps:
    - Read the training and prediction files with ``read_squad_examples``.
    - Convert them to features with ``convert_examples_to_features``, using a
      ``ChineseFullTokenizer``; ``max_seq_length``, ``doc_stride`` and
      ``max_query_length`` are forwarded to the converter.
    - When ``args.fake_file_1`` / ``args.fake_file_2`` are set, append their
      examples and features to the training data.
    - Initialise the model from one of:
      a BERT checkpoint (``load_model_type == 'bert'``);
      a full checkpoint (``'all'``);
      random weights otherwise.
    - Then either:
      train it with ``BasicTrainer``, passing ``predict`` as the evaluation
      callback; or, when only ``do_predict`` is set, evaluate once with
      ``predict`` and print the result.

    Distributed training (``local_rank != -1``) raises ``NotImplementedError``.
    """
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; gradients are accumulated over
    # gradient_accumulation_steps passes to reach train_batch_size.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    #load bert config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #read data
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None

    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file,
                                     do_lower_case=args.do_lower_case)
    # One converter shared by the train, fake-data and predict files.
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = read_and_convert(
            args.train_file,
            is_training=True,
            do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples,
            convert_fn=convert_fn)
        # Optional extra training data, appended to the real training set.
        if args.fake_file_1:
            fake_examples1, fake_features1 = read_and_convert(
                args.fake_file_1,
                is_training=True,
                do_lower_case=args.do_lower_case,
                read_fn=read_squad_examples,
                convert_fn=convert_fn)
            train_examples += fake_examples1
            train_features += fake_features1
        if args.fake_file_2:
            fake_examples2, fake_features2 = read_and_convert(
                args.fake_file_2,
                is_training=True,
                do_lower_case=args.do_lower_case,
                read_fn=read_squad_examples,
                convert_fn=convert_fn)
            train_examples += fake_examples2
            train_features += fake_features2

        # Steps are counted over features (not original examples); used as
        # the optimizer's t_total.
        num_train_steps = int(len(train_features) /
                              args.train_batch_size) * args.num_train_epochs

    if args.do_predict:
        eval_examples, eval_features = read_and_convert(
            args.predict_file,
            is_training=False,
            do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples,
            convert_fn=convert_fn)

    #Build Model and load checkpoint
    model_S = BertForQASimple(bert_config_S, args)
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # k[5:] strips the 'bert.' prefix to match model_S.bert's keys.
        state_weight = {
            k[5:]: v
            for k, v in state_dict_S.items() if k.startswith('bert.')
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                       strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    # Distributed training is not supported; several local GPUs use DataParallel.
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)

    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### Training (BasicTrainer: plain supervised, no distillation) ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForQASimpleAdaptorTraining)

        # Stack the training features into tensors. NOTE(review): the field
        # order below presumably has to match what
        # BertForQASimpleAdaptorTraining expects — confirm before reordering.
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_doc_mask = torch.tensor([f.doc_mask for f in train_features],
                                    dtype=torch.float)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)

        train_dataset = TensorDataset(all_input_ids, all_segment_ids,
                                      all_input_mask, all_doc_mask,
                                      all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        callback_func = partial(predict,
                                eval_examples=eval_examples,
                                eval_features=eval_features,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    # Prediction-only mode: evaluate the loaded model once.
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
Пример #5
0
def main():
    """Distill several teacher models into one student on a GLUE-style task.

    Teachers and the student are described by ``args.model_config_json`` and
    parsed with ``parse_model_config``. Each entry provides a model type,
    config and checkpoint; the student also provides a tokenizer and an input
    prefix. All datasets are tokenised with the student's tokenizer.

    With ``do_train``:
    - Every teacher is loaded with ``strict=True`` and put in eval mode.
    - The student is trained by ``MultiTeacherDistiller`` using AdamW, a
      linear warm-up schedule and ``kd_loss_type='ce'``.
    - ``predict`` is passed as the evaluation callback.

    With only ``do_predict``, the student is evaluated once and the result is
    printed. Distributed training (``local_rank != -1``) raises
    ``NotImplementedError``.
    """
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; gradients are accumulated over
    # gradient_accumulation_steps passes to reach train_batch_size.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the teacher / student description.
    teachers_and_student = parse_model_config(args.model_config_json)

    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None

    # Data is prepared for the student's tokenizer/prefix; teachers see the
    # same tensors.
    tokenizer_S = teachers_and_student['student']['tokenizer']
    prefix_S = teachers_and_student['student']['prefix']

    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer_S,
                                                prefix=prefix_S,
                                                evaluate=False)
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli",
                           "mnli-mm") if args.task_name == "mnli" else (
                               args.task_name, )
        for eval_task in eval_task_names:
            eval_datasets.append(
                load_and_cache_examples(args,
                                        eval_task,
                                        tokenizer_S,
                                        prefix=prefix_S,
                                        evaluate=True))
    logger.info("Data loaded")

    #Build Model and load checkpoint
    # Teachers are only needed for training.
    if args.do_train:
        model_Ts = []
        for teacher in teachers_and_student['teachers']:
            model_type_T = teacher['model_type']
            model_config_T = teacher['config']
            checkpoint_T = teacher['checkpoint']

            _, _, model_class_T = MODEL_CLASSES[model_type_T]
            model_T = model_class_T(model_config_T, num_labels=num_labels)
            state_dict_T = torch.load(checkpoint_T, map_location='cpu')
            # strict=True: a teacher must match its checkpoint exactly.
            missing_keys, un_keys = model_T.load_state_dict(state_dict_T,
                                                            strict=True)
            logger.info(f"Teacher Model {model_type_T} loaded")
            model_T.to(device)
            model_T.eval()
            model_Ts.append(model_T)

    student = teachers_and_student['student']
    model_type_S = student['model_type']
    model_config_S = student['config']
    checkpoint_S = student['checkpoint']
    _, _, model_class_S = MODEL_CLASSES[model_type_S]
    model_S = model_class_S(model_config_S, num_labels=num_labels)
    if checkpoint_S is not None:
        # strict=False: partial student checkpoints are allowed; mismatches
        # are only logged.
        state_dict_S = torch.load(checkpoint_S, map_location='cpu')
        missing_keys, un_keys = model_S.load_state_dict(state_dict_S,
                                                        strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{un_keys}")
    else:
        logger.warning("Initializing student randomly")
    logger.info("Student Model loaded")
    model_S.to(device)

    # Distributed training is not supported; several local GPUs use DataParallel.
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [
                    torch.nn.DataParallel(model_T) for model_T in model_Ts
                ]
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)

    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Optimizer steps = forward batches per epoch / accumulation steps,
        # times epochs; drives the linear warm-up schedule below.
        num_train_steps = int(
            len(train_dataloader) // args.gradient_accumulation_steps *
            args.num_train_epochs)

        ########## DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            fp16=args.fp16,
            device=args.device)

        distill_config = DistillationConfig(temperature=args.temperature,
                                            kd_loss_type='ce')

        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor_T = BertForGLUESimpleAdaptor
        adaptor_S = BertForGLUESimpleAdaptor

        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts,
                                          model_S=model_S,
                                          adaptor_T=adaptor_T,
                                          adaptor_S=adaptor_S)

        optimizer = AdamW(all_trainable_params, lr=args.learning_rate)
        # The scheduler is built by the distiller from this class and its args.
        scheduler_class = get_linear_schedule_with_warmup
        scheduler_args = {
            'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
            'num_training_steps': num_train_steps
        }

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func,
                            max_grad_norm=1)

    # Prediction-only mode: evaluate the student once.
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
Пример #6
0
def main():
    """Train an ELECTRA token-classification student, optionally under DDP.

    Reads arguments from ``config``, loads train/predict features with
    ``read_features``, builds ``ElectraForTokenClassification`` from
    ``args.bert_config_file_S``, optionally initialises it from a checkpoint
    (``args.load_model_type``), and trains it with ``BasicTrainer``.
    This is supervised training with no teacher.
    When only ``--do_predict`` is given, ``ddp_predict`` is run once and its
    result is printed.
    """
    # parse arguments
    config.parse()
    args = config.args

    # Only the main process (local_rank -1 or 0) logs at INFO level.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -  %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}"
        )

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Batch size of a single forward pass; gradients are accumulated over
    # gradient_accumulation_steps forwards to reach train_batch_size.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # load student config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None

    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    # Non-main ranks wait here until rank 0 has read the features (presumably
    # so that only one process builds any on-disk cache), then all proceed.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(
            args.train_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(
            args.predict_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

    if args.local_rank == 0:
        torch.distributed.barrier()
    # Build model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Non-strict load: the pretrained checkpoint need not contain the
        # task-specific head, so missing/unexpected keys are only logged.
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,
                                                                strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            # Layer-wise learning-rate decay: the classifier keeps the base lr,
            # the k-th encoder layer counted from the top (k = 1 for the last
            # layer) gets lr * lr_decay**k, and the embeddings get
            # lr * lr_decay**(n_layers + 1). This works for any encoder depth.
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params,
                                               lr=args.learning_rate)

            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(
                    model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            embed_params = [
                (name, value)
                for name, value in model_S.electra.named_parameters()
                if 'embedding' in name
            ]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            # Sanity check: the number of grouped parameters must equal the
            # model's parameter count, so no parameter was left out.
            assert sum(map(lambda x:len(x['params']), all_trainable_params))==len(list(model_S.parameters())),\
                (sum(map(lambda x:len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params,
                                                     lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)

        # Optimizer steps (not forward passes) over the whole run.
        num_train_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params,
                          lr=args.learning_rate,
                          correct_bias=False)
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps)
            }
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps),
                'num_training_steps': num_train_steps
            }
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank,
                       len(train_examples))
        logger.warning("local_rank %d Num split examples = %d",
                       args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d",
                       args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d",
                       args.local_rank, num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")

        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # evaluate the model in a single process in ddp_predict
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S,
                          eval_examples,
                          eval_dataset,
                          step=0,
                          args=args)
        print(res)
Пример #7
0
def main():
    """Supervised fine-tuning of a single sentence-pair classifier (no distillation).

    Trains ``model_S`` (ELECTRA, ALBERT or BERT, chosen by
    ``args.model_architecture``) with ``BasicTrainer`` on ``args.task_name``.
    The result can later serve as a teacher model. When only ``--do_predict``
    is given, ``predict`` is run once and its result is printed.

    Checkpoint loading is controlled by ``args.load_model_type``:
      * ``'bert'`` -- load pretrained weights from ``args.init_checkpoint_S``
        (only the ``bert.*`` part for BERT; the whole state dict, non-strict,
        for ELECTRA/ALBERT);
      * ``'all'``  -- load a fully fine-tuned model from ``args.tuned_checkpoint_S``;
      * anything else -- keep the random initialisation.
    """
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass (gradient accumulation makes up the rest).
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Prepare the task, e.g. MNLI labels: ['contradiction', 'entailment', 'neutral'].
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Student config. The ELECTRA/ALBERT configs carry num_labels and the
    # hidden-state/attention output switches; the BERT model receives
    # num_labels through its constructor instead.
    if args.model_architecture in ("electra", "albert"):
        config_class = ElectraConfig if args.model_architecture == "electra" else AlbertConfig
        bert_config_S = config_class.from_json_file(args.bert_config_file_S)
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        bert_config_S.num_labels = num_labels
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    # ELECTRA and BERT both use the BERT WordPiece tokenizer.
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load datasets and compute the number of optimizer steps.
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset)/args.train_batch_size) * args.num_train_epochs
        logger.info("训练数据集已加载")
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both the matched and mismatched dev sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("预测数据集已加载")

    # Build the student model (only one model is trained here).
    if args.model_architecture == "electra":
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)

    # Initialise the student's parameters according to load_model_type.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.model_architecture in ("electra", "albert"):
            # Non-strict load of the whole pretrained state dict; the
            # classification head may be absent, so keys are only logged.
            missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
            logger.info(f"missing keys:{missing_keys}")
            logger.info(f"unexpected keys:{unexpected_keys}")
        elif args.only_load_embedding:
            # k[5:] strips the 'bert.' prefix so keys match model_S.bert.
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
            logger.info("Model loaded")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"missing_keys,注意丢失的参数{missing_keys}")
            logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    # move the model to the target device
    model_S.to(device)
    if args.local_rank != -1:
        # Distributed training is not supported by this script.
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # All named parameters, split into parameter groups by divide_parameters
        # (presumably weight-decay vs. no-decay groups).
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("要训练的模型参数量组是,包括decay_group和no_decay_group: %d", len(all_trainable_params))
        # BERTAdam applies its own warmup schedule, so no external scheduler is passed.
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps, schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** 开始训练 *****")
        logger.info("  样本数是 = %d", len(train_dataset))
        logger.info("  前向 batch size = %d", forward_batch_size)
        logger.info("  训练的steps = %d", num_train_steps)

        ########### TRAINING CONFIG ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps = args.gradient_accumulation_steps,
            ckpt_frequency = args.ckpt_frequency,
            log_dir = args.output_dir,
            output_dir = args.output_dir,
            device = args.device)

        # BasicTrainer performs plain supervised training, not distillation.
        distiller = BasicTrainer(train_config = train_config,
                                 model = model_S,
                                 adaptor = BertForGLUESimpleAdaptorTraining)

        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size, drop_last=True)
        # Callback that evaluates on eval_datasets during training.
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
Пример #8
0
def main():
    """Distil one or more fine-tuned BERT teachers into a BERT student on a GLUE task.

    Teachers are built from ``args.bert_config_file_T`` and loaded from each
    path in ``args.tuned_checkpoint_Ts``. The student is built from
    ``args.bert_config_file_S`` and initialised according to
    ``args.load_model_type``. Training uses ``MultiTeacherDistiller`` with a
    cross-entropy KD loss at ``args.temperature``. When only ``--do_predict``
    is given, ``predict`` is run once on the student and its result is printed.

    Raises:
        ValueError: if ``--do_train`` is requested without teacher checkpoints.
    """
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass (gradient accumulation makes up the rest).
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # load bert configs for the teachers and the student
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset)/args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both the matched and mismatched dev sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")

    # Build models and load checkpoints
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load teachers (kept in eval mode: they only provide soft targets).
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for _ in range(len(args.tuned_checkpoint_Ts))]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s" % ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        # Without teacher checkpoints there is nothing to distil from, so only
        # prediction with the student is possible. Previously do_train with
        # do_predict passed the assert and then failed with a NameError on model_Ts.
        if args.do_train:
            raise ValueError("do_train requires tuned_checkpoint_Ts (teacher checkpoints) for distillation")
        assert args.do_predict is True
        model_Ts = []
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # k[5:] strips the 'bert.' prefix so keys match model_S.bert.
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)

    if args.local_rank != -1:
        # Distributed training is not supported by this script.
        raise NotImplementedError
    if n_gpu > 1:
        if args.do_train:
            model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

        # BERTAdam applies its own warmup schedule, so no external scheduler is passed.
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps, schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps = args.gradient_accumulation_steps,
            ckpt_frequency = args.ckpt_frequency,
            log_dir = args.output_dir,
            output_dir = args.output_dir,
            device = args.device)

        distill_config = DistillationConfig(
            temperature = args.temperature,
            kd_loss_type = 'ce')

        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        # Teachers and student share the same BERT architecture, hence one adaptor.
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask = False)

        distiller = MultiTeacherDistiller(train_config = train_config,
                                          distill_config = distill_config,
                                          model_T = model_Ts, model_S = model_S,
                                          adaptor_T=adaptor,
                                          adaptor_S=adaptor)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
Пример #9
0
def main():
    """Train an ELECTRA token-classification student, optionally under DDP.

    Reads arguments from ``config``, loads train/predict features with
    ``read_features``, builds ``ElectraForTokenClassification`` from
    ``args.bert_config_file_S``, optionally initialises it from a checkpoint
    (``args.load_model_type``), and trains it with ``BasicTrainer``.
    This is supervised training with no teacher.
    When only ``--do_predict`` is given, ``ddp_predict`` is run once and its
    result is printed.
    """
    # parse arguments
    config.parse()
    args = config.args

    # Only the main process (local_rank -1 or 0) logs at INFO level.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -  %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}"
        )

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Batch size of a single forward pass; gradients are accumulated over
    # gradient_accumulation_steps forwards to reach train_batch_size.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the ElectraConfig; hidden states are returned only when
    # args.output_encoded_layers == 'true'. num_labels is the number of tag classes.
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    # load the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    # Non-main ranks wait here until rank 0 has read the features (presumably
    # so that only one process builds any on-disk cache), then all proceed.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        # read_features returns the raw examples and a dataset built from them
        # (per the original notes: token ids, input mask, label ids -- verify).
        print("开始加载训练集数据")
        train_examples, train_dataset = read_features(
            args.train_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)
    if args.do_predict:
        print("开始加载测试集数据")
        eval_examples, eval_dataset = read_features(
            args.predict_file,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length)

    if args.local_rank == 0:
        torch.distributed.barrier()
    # Build model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load the student parameters according to load_model_type.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Non-strict load: the pretrained checkpoint need not contain the
        # task-specific head, so missing/unexpected keys are only logged.
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,
                                                                strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    # move the model to the device
    model_S.to(device)

    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            # Layer-wise learning-rate decay: the classifier (weight, bias)
            # keeps the base lr, the k-th encoder layer counted from the top
            # (k = 1 for the last layer) gets lr * lr_decay**k, and the
            # embeddings get lr * lr_decay**(n_layers + 1). This works for
            # any encoder depth.
            outputs_params = list(model_S.classifier.named_parameters())
            # divide_parameters presumably splits into decay / no-decay groups.
            outputs_params = divide_parameters(outputs_params,
                                               lr=args.learning_rate)

            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(
                    model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i}:第{n}层的学习率是:{lr}")
            embed_params = [
                (name, value)
                for name, value in model_S.electra.named_parameters()
                if 'embedding' in name
            ]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embedding层的学习率 lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            # Sanity check: the number of grouped parameters must equal the
            # model's parameter count, so no parameter was left out.
            assert sum(map(lambda x:len(x['params']), all_trainable_params))==len(list(model_S.parameters())),\
                (sum(map(lambda x:len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params,
                                                     lr=args.learning_rate)
        logger.info("可训练的参数all_trainable_params共有: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        # build the dataloader
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Optimizer steps (not forward passes) over all epochs.
        num_train_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params,
                          lr=args.learning_rate,
                          correct_bias=False)
        if args.official_schedule == 'const':
            # constant learning rate after warmup
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps)
            }
        elif args.official_schedule == 'linear':
            # linear decay after warmup; warmup covers warmup_proportion of all steps
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps),
                'num_training_steps': num_train_steps
            }
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** 开始 训练 *****")
        logger.warning("local_rank %d 原样本数 = %d", args.local_rank,
                       len(train_examples))
        logger.warning("local_rank %d split之后的样本数 = %d", args.local_rank,
                       len(train_dataset))
        logger.warning("local_rank %d 前向 batch size = %d", args.local_rank,
                       forward_batch_size)
        logger.warning("local_rank %d 训练的steps = %d", args.local_rank,
                       num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,  # checkpoint frequency
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)

        logger.info("训练的配置文件:")
        logger.info(f"{train_config}")
        # initialise the trainer
        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # Callback that evaluates with ddp_predict during training.
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S,
                          eval_examples,
                          eval_dataset,
                          step=0,
                          args=args)
        print(res)