def load_model(self):
        """Load the tokenizer and the fine-tuned classifier used for prediction.

        Reads ``self.num_labels``, ``self.device`` and ``self.verbose``. As a
        side effect sets ``self.args``, ``self.vocab_file``,
        ``self.bert_config_file_S``, ``self.tuned_checkpoint_S``,
        ``self.max_seq_length`` and ``self.predict_batch_size``.

        :return: ``(tokenizer, model_S)`` -- a ``BertTokenizer`` built from
            ``self.vocab_file`` and a ``BertSPCSimple`` whose weights are
            loaded from ``self.tuned_checkpoint_S``.
        :raises ValueError: if ``self.max_seq_length`` exceeds the config's
            ``max_position_embeddings``.
        """
        # Build the model options directly. ``ArgumentParser().parse_args()``
        # parsed the host process's ``sys.argv``, so any unrelated command-line
        # flag (or ``-h``) made it exit with "unrecognized arguments"; with no
        # flags it produced exactly this namespace.
        self.args = argparse.Namespace(
            output_encoded_layers=True,
            output_attention_layers=True,
            output_att_score=True,
            output_att_sum=True,
        )
        # The vocabulary is the same for the teacher and the student models.
        self.vocab_file = "bert_model/vocab.txt"
        # Teacher config + fine-tuned teacher checkpoint. To predict with the
        # distilled student instead, use:
        #   student config:          config/chinese_bert_config_L4t.json
        #   distilled student model: distil_model/gs8316.pkl
        self.bert_config_file_S = "bert_model/config.json"
        self.tuned_checkpoint_S = "trained_teacher_model/gs3024.pkl"
        self.max_seq_length = 70
        # Batch size used for prediction.
        self.predict_batch_size = 64

        # Load the model config and check that the configured sequence length
        # fits the model's position embeddings.
        bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)
        if self.max_seq_length > bert_config_S.max_position_embeddings:
            raise ValueError(
                f"max_seq_length={self.max_seq_length} exceeds "
                f"max_position_embeddings="
                f"{bert_config_S.max_position_embeddings} in "
                f"{self.bert_config_file_S}")

        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        model_S = BertSPCSimple(bert_config_S,
                                num_labels=self.num_labels,
                                args=self.args)
        # map_location places the checkpoint tensors on self.device while
        # loading. NOTE(review): the model itself is neither moved to
        # self.device nor switched to eval mode here; presumably the caller
        # does that -- confirm.
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        model_S.load_state_dict(state_dict_S)
        if self.verbose:
            print("模型已加载")

        return tokenizer, model_S
# Example #2
# 0
    def load_macbert_model(self):
        """Load the MacBERT tokenizer and fine-tuned classifier for prediction.

        Reads ``self.num_labels``, ``self.device`` and ``self.verbose``.
        Nothing is returned; the results are stored on the instance:

        - ``self.predict_tokenizer``: a ``BertTokenizer``.
        - ``self.predict_model``: a ``BertSPCSimple`` with weights from
          ``self.tuned_checkpoint_S``.

        Also sets ``self.args``, ``self.vocab_file``,
        ``self.bert_config_file_S`` and ``self.tuned_checkpoint_S``.
        """
        # Build the model options directly. ``ArgumentParser().parse_args()``
        # parsed the host process's ``sys.argv``, so any unrelated command-line
        # flag (or ``-h``) made it exit with "unrecognized arguments"; with no
        # flags it produced exactly this namespace.
        self.args = argparse.Namespace(
            output_encoded_layers=True,
            output_attention_layers=True,
            output_att_score=True,
            output_att_sum=True,
        )
        # The vocabulary is the same for the teacher and the student models.
        self.vocab_file = "mac_bert_model/vocab.txt"
        # Fine-tuned MacBERT teacher. Alternative checkpoints previously used:
        #   trained_teacher_model/macbert_894_cosmetics.pkl
        #   trained_teacher_model/macbert_teacher_max75len_5000.pkl
        # A distilled student can be used instead with
        #   config/chinese_bert_config_L4t.json + distil_model/gs8316.pkl
        self.bert_config_file_S = "mac_bert_model/config.json"
        self.tuned_checkpoint_S = "trained_teacher_model/macbert_2290_cosmetics_weibo.pkl"
        bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)

        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        model_S = BertSPCSimple(bert_config_S,
                                num_labels=self.num_labels,
                                args=self.args)
        # map_location places the checkpoint tensors on self.device while
        # loading. NOTE(review): the model itself is not moved to self.device
        # or put into eval mode here -- confirm the caller does that.
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        model_S.load_state_dict(state_dict_S)
        if self.verbose:
            print("模型已加载")
        self.predict_tokenizer = tokenizer
        self.predict_model = model_S
        logger.info("macbert预测模型加载完成")
# Example #3
# 0
    def load_train_model(self):
        """Initialize the model to be fine-tuned.

        Stores the training hyperparameters on the instance and builds a
        ``BertSPCSimple`` classifier. Only its ``bert`` encoder is initialized
        from ``self.tuned_checkpoint_S``, using the checkpoint keys prefixed
        ``bert.``; parameters outside ``model_S.bert`` keep their initial
        values.

        Reads ``self.num_labels`` and ``self.device``. Sets
        ``self.train_tokenizer`` and ``self.train_model``.

        :raises RuntimeError: if the checkpoint does not provide every
            parameter of the encoder.
        """
        # Build the model options directly. ``ArgumentParser().parse_args()``
        # parsed the host process's ``sys.argv``, so any unrelated command-line
        # flag (or ``-h``) made it exit with "unrecognized arguments"; with no
        # flags it produced exactly this namespace.
        self.args = argparse.Namespace(
            output_encoded_layers=True,
            output_attention_layers=True,
            output_att_score=True,
            output_att_sum=True,
        )
        self.learning_rate = 2e-05
        # Fraction of training used for learning-rate warmup.
        self.warmup_proportion = 0.1
        self.num_train_epochs = 1
        # Learning-rate scheduler and its options.
        self.schedule = 'slanted_triangular'
        self.s_opt1 = 30.0
        self.s_opt2 = 0.0
        self.s_opt3 = 1.0
        self.weight_decay_rate = 0.01
        # Save a checkpoint every this many epochs.
        self.ckpt_frequency = 1
        # Where checkpoints and logs are written.
        self.output_dir = "output_root_dir/train_api"
        # Number of gradient-accumulation steps.
        self.gradient_accumulation_steps = 1
        # The vocabulary is the same for the teacher and the student models.
        self.vocab_file = "mac_bert_model/vocab.txt"
        self.bert_config_file_S = "mac_bert_model/config.json"
        # Presumably the pretrained (not yet fine-tuned) MacBERT weights.
        self.tuned_checkpoint_S = "mac_bert_model/pytorch_model.bin"
        bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)

        tokenizer = BertTokenizer(vocab_file=self.vocab_file)

        model_S = BertSPCSimple(bert_config_S,
                                num_labels=self.num_labels,
                                args=self.args)
        state_dict_S = torch.load(self.tuned_checkpoint_S,
                                  map_location=self.device)
        # Keep only the encoder weights and strip the "bert." prefix so the
        # keys match ``model_S.bert``'s own state dict.
        prefix = "bert."
        encoder_weights = {
            key[len(prefix):]: value
            for key, value in state_dict_S.items()
            if key.startswith(prefix)
        }
        missing_keys, _ = model_S.bert.load_state_dict(encoder_weights,
                                                       strict=False)
        # Every encoder parameter must come from the checkpoint. Raise instead
        # of ``assert`` so the check is not stripped under ``python -O``.
        if missing_keys:
            raise RuntimeError(
                f"{self.tuned_checkpoint_S} is missing encoder weights: "
                f"{list(missing_keys)}")
        self.train_tokenizer = tokenizer
        self.train_model = model_S
        logger.info(f"训练模型{self.tuned_checkpoint_S}加载完成")
# Example #4
# 0
def main():
    """Run prediction with a fine-tuned ``BertSPCSimple`` checkpoint.

    Settings come from ``config.args`` after ``config.parse()``. The function:

    - loads the evaluation dataset of ``args.task_name``;
    - restores the model from ``args.tuned_checkpoint_S``;
    - moves the model to the device chosen by ``args_check``;
    - prints the result of ``predict``.

    :raises ValueError: if ``args.max_seq_length`` exceeds the model's
        ``max_position_embeddings`` or ``args.tuned_checkpoint_S`` is unset.
    """
    # Parse and log the arguments.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # Validate the arguments and pick the device.
    device, _ = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the model config; the configured sequence length must fit its
    # position embeddings. Raise instead of assert so it survives -O.
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    if args.max_seq_length > bert_config_S.max_position_embeddings:
        raise ValueError(
            f"max_seq_length={args.max_seq_length} exceeds "
            f"max_position_embeddings={bert_config_S.max_position_embeddings}")

    # Prepare the task and its label set.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load the evaluation data.
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    eval_dataset = load_and_cache_examples(args,
                                           args.task_name,
                                           tokenizer,
                                           evaluate=True)
    logger.info("评估数据集已加载")

    # Build the model and restore the fine-tuned weights.
    model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)
    if args.tuned_checkpoint_S is None:
        raise ValueError("args.tuned_checkpoint_S must be set for prediction")
    state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
    model_S.load_state_dict(state_dict_S)
    logger.info("Student模型已加载")
    # The weights were loaded onto the CPU; move the model to the device
    # chosen by args_check, as the training entry points do.
    model_S.to(device)

    res = predict(model_S, eval_dataset, args=args)
    print(res)
def main():
    """Fine-tune a single ``BertForGLUESimple`` model with supervision.

    Settings come from ``config.args`` after ``config.parse()``.

    - With ``args.do_train``: train with ``BasicTrainer``, using ``predict``
      as the checkpoint callback.
    - With only ``args.do_predict``: evaluate once and print the result.

    How the model is initialized depends on ``args.load_model_type``:

    - ``'bert'``: encoder weights from ``args.init_checkpoint_S``; only the
      embeddings when ``args.only_load_embedding`` is set.
    - ``'all'``: the full state dict from ``args.tuned_checkpoint_S``.
    - anything else: random initialization.

    :raises ValueError: if ``max_seq_length`` is too large, or a checkpoint
        path required by ``load_model_type`` is missing.
    :raises RuntimeError: if a full encoder load leaves parameters missing.
    :raises NotImplementedError: for distributed runs (``local_rank != -1``).
    """
    # Parse and log the arguments.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Seed every RNG in use.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Validate the arguments and pick the device.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass when accumulating gradients.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the model config; the sequence length must fit its position
    # embeddings. Raise instead of assert so it survives -O.
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    if args.max_seq_length > bert_config_S.max_position_embeddings:
        raise ValueError(
            f"max_seq_length={args.max_seq_length} exceeds "
            f"max_position_embeddings={bert_config_S.max_position_embeddings}")

    # Prepare the GLUE-style task.
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral'].
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Read the data and compute the number of optimizer steps.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args,
                                                        args.aux_task_name,
                                                        tokenizer,
                                                        evaluate=False,
                                                        is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        # MNLI is evaluated on both its matched and mismatched dev sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (
            args.task_name, )
        eval_datasets = [
            load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
            for eval_task in eval_task_names
        ]
    logger.info("数据集已加载")

    # Only one model is built: this script trains a single model with
    # supervision (e.g. the teacher on MNLI) rather than distilling.
    model_S = BertForGLUESimple(bert_config_S,
                                num_labels=num_labels,
                                args=args)
    if args.load_model_type == 'bert':
        if args.init_checkpoint_S is None:
            raise ValueError(
                "args.init_checkpoint_S must be set when load_model_type is 'bert'")
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Checkpoint keys look like "bert.<name>"; strip the prefix so they
        # match model_S.bert's state dict. Optionally keep only embeddings.
        prefix = 'bert.'
        wanted = 'bert.embeddings' if args.only_load_embedding else prefix
        state_weight = {
            k[len(prefix):]: v
            for k, v in state_dict_S.items() if k.startswith(wanted)
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                       strict=False)
        if args.only_load_embedding:
            # Only embeddings were loaded, so encoder layers are reported missing.
            logger.info(f"Missing keys {list(missing_keys)}")
        elif missing_keys:
            raise RuntimeError(
                f"{args.init_checkpoint_S} is missing encoder weights: "
                f"{list(missing_keys)}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        if args.tuned_checkpoint_S is None:
            raise ValueError(
                "args.tuned_checkpoint_S must be set when load_model_type is 'all'")
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    # Distributed training is not supported; multiple local GPUs use
    # DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer with warmup and the configured learning-rate schedule.
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # BasicTrainer does plain supervised training, not distillation; it
        # can be used to train the teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)

        # local_rank == -1 is guaranteed by the check above.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
# Example #6
# 0
def main():
    """Distill a fine-tuned teacher into a student ``BertForGLUESimple``.

    Settings come from ``config.args`` after ``config.parse()``. The teacher
    is restored from ``args.tuned_checkpoint_T``. The student is initialized
    according to ``args.load_model_type``:

    - ``'bert'``: encoder weights from ``args.init_checkpoint_S``; only the
      embeddings when ``args.only_load_embedding`` is set.
    - ``'all'``: the full state dict from ``args.tuned_checkpoint_S``.
    - anything else: random initialization.

    With ``args.do_train`` the student is trained with ``GeneralDistiller``.
    Intermediate-layer matches come from ``args.matches``. With only
    ``args.do_predict`` the student is evaluated once and the result printed.

    :raises ValueError: if ``max_seq_length`` is too large, a required
        checkpoint path is missing, or no teacher checkpoint is given for a
        run that is not prediction-only.
    :raises RuntimeError: if a full encoder load leaves parameters missing.
    :raises NotImplementedError: for distributed runs (``local_rank != -1``).
    """
    # Parse and log the arguments.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Seed every RNG in use.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Validate the arguments and pick the device.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass when accumulating gradients.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load both model configs; the sequence length must fit each model's
    # position embeddings. Raise instead of assert so it survives -O.
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    for name, cfg in (("teacher", bert_config_T), ("student", bert_config_S)):
        if args.max_seq_length > cfg.max_position_embeddings:
            raise ValueError(
                f"max_seq_length={args.max_seq_length} exceeds the {name} "
                f"max_position_embeddings={cfg.max_position_embeddings}")

    # Prepare the GLUE-style task.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Read the data. Here load_and_cache_examples returns (dataset, examples).
    # NOTE(review): ``examples`` is overwritten by every load below, so it ends
    # up holding the examples of the last dataset loaded (the last eval task
    # when do_predict is set, e.g. only mnli-mm for mnli). Confirm that is
    # what predict() expects.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args,
                                                          args.task_name,
                                                          tokenizer,
                                                          evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(
                args,
                args.aux_task_name,
                tokenizer,
                evaluate=False,
                is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both its matched and mismatched dev sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (
            args.task_name, )
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args,
                                                             eval_task,
                                                             tokenizer,
                                                             evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("数据集加载成功")

    # Build the teacher and student models.
    model_T = BertForGLUESimple(bert_config_T,
                                num_labels=num_labels,
                                args=args)
    model_S = BertForGLUESimple(bert_config_S,
                                num_labels=num_labels,
                                args=args)
    # Restore the fine-tuned teacher. Distillation needs a trained teacher;
    # without one only prediction with the student makes sense. (The previous
    # ``assert args.do_predict is True`` let do_train + do_predict distill
    # from a randomly initialized teacher.)
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    elif args.do_train or not args.do_predict:
        raise ValueError(
            "args.tuned_checkpoint_T is required unless only predicting "
            "(do_train=False, do_predict=True)")
    # Initialize the student.
    if args.load_model_type == 'bert':
        if args.init_checkpoint_S is None:
            raise ValueError(
                "args.init_checkpoint_S must be set when load_model_type is 'bert'")
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Checkpoint keys look like "bert.<name>"; strip the prefix so they
        # match model_S.bert's state dict. Optionally keep only embeddings.
        prefix = 'bert.'
        wanted = 'bert.embeddings' if args.only_load_embedding else prefix
        state_weight = {
            k[len(prefix):]: v
            for k, v in state_dict_S.items() if k.startswith(wanted)
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                       strict=False)
        if args.only_load_embedding:
            # Only embeddings were loaded, so encoder layers are reported missing.
            logger.info(f"Missing keys {list(missing_keys)}")
        elif missing_keys:
            raise RuntimeError(
                f"{args.init_checkpoint_S} is missing encoder weights: "
                f"{list(missing_keys)}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        if args.tuned_checkpoint_S is None:
            raise ValueError(
                "args.tuned_checkpoint_S must be set when load_model_type is 'all'")
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Student模型没有可加载参数,随机初始化参数 randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    # Distributed training is not supported; multiple local GPUs use
    # DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_T = torch.nn.DataParallel(model_T)
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Only the student's parameters are optimized.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer with warmup and the configured learning-rate schedule.
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # ``matches`` maps preset names to intermediate-layer match configs.
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = [
                m for match_name in args.matches for m in matches[match_name]
            ]
        logger.info(f"中间层match信息: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        logger.info(f"训练配置: {train_config}")
        logger.info(f"蒸馏配置: {distill_config}")
        # Teacher and student are both BertForGLUESimple, so the same adaptor
        # serves both.
        adaptor = partial(BertForGLUESimpleAdaptor,
                          no_logits=args.no_logits,
                          no_mask=args.no_inputs_mask)
        # General distiller supporting intermediate-state matching.
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T,
                                     model_S=model_S,
                                     adaptor_T=adaptor,
                                     adaptor_S=adaptor)

        # local_rank == -1 is guaranteed by the check above.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args,
                                examples=examples)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S,
                      eval_datasets,
                      step=0,
                      args=args,
                      examples=examples,
                      label_list=label_list)
        print(res)
# Example #7
# 0
def main():
    """Train and/or evaluate a ``BertForQASimple`` reading-comprehension model.

    Settings come from ``config.args`` after ``config.parse()``.

    - Training data: ``args.train_file`` plus the optional
      ``args.fake_file_1`` / ``args.fake_file_2``.
    - Evaluation data: ``args.predict_file``.

    How the model is initialized depends on ``args.load_model_type``:

    - ``'bert'``: encoder weights from ``args.init_checkpoint_S``.
    - ``'all'``: the full state dict from ``args.tuned_checkpoint_S``.
    - anything else: random initialization.

    With ``args.do_train`` the model is trained with ``BasicTrainer``, using
    ``predict`` as the checkpoint callback. With only ``args.do_predict`` it
    is evaluated once and the result printed.

    :raises ValueError: if ``max_seq_length`` is too large or a checkpoint
        path required by ``load_model_type`` is missing.
    :raises RuntimeError: if the encoder load leaves parameters missing.
    :raises NotImplementedError: for distributed runs (``local_rank != -1``).
    """
    # Parse and log the arguments.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Seed every RNG in use.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Validate the arguments and pick the device.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass when accumulating gradients.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the model config; the sequence length must fit its position
    # embeddings. Raise instead of assert so it survives -O.
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    if args.max_seq_length > bert_config_S.max_position_embeddings:
        raise ValueError(
            f"max_seq_length={args.max_seq_length} exceeds "
            f"max_position_embeddings={bert_config_S.max_position_embeddings}")

    # Read the data.
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None

    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file,
                                     do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = read_and_convert(
            args.train_file,
            is_training=True,
            do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples,
            convert_fn=convert_fn)
        # Append the optional extra training files, in order.
        for fake_file in (args.fake_file_1, args.fake_file_2):
            if not fake_file:
                continue
            fake_examples, fake_features = read_and_convert(
                fake_file,
                is_training=True,
                do_lower_case=args.do_lower_case,
                read_fn=read_squad_examples,
                convert_fn=convert_fn)
            train_examples += fake_examples
            train_features += fake_features

        num_train_steps = int(len(train_features) /
                              args.train_batch_size) * args.num_train_epochs

    if args.do_predict:
        eval_examples, eval_features = read_and_convert(
            args.predict_file,
            is_training=False,
            do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples,
            convert_fn=convert_fn)

    # Build the model and load its initial weights.
    model_S = BertForQASimple(bert_config_S, args)
    if args.load_model_type == 'bert':
        if args.init_checkpoint_S is None:
            raise ValueError(
                "args.init_checkpoint_S must be set when load_model_type is 'bert'")
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keep the encoder weights and strip the "bert." prefix so the keys
        # match model_S.bert's state dict.
        prefix = 'bert.'
        state_weight = {
            k[len(prefix):]: v
            for k, v in state_dict_S.items() if k.startswith(prefix)
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                       strict=False)
        if missing_keys:
            raise RuntimeError(
                f"{args.init_checkpoint_S} is missing encoder weights: "
                f"{list(missing_keys)}")
    elif args.load_model_type == 'all':
        if args.tuned_checkpoint_S is None:
            raise ValueError(
                "args.tuned_checkpoint_S must be set when load_model_type is 'all'")
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    # Distributed training is not supported; multiple local GPUs use
    # DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        # Optimizer with warmup and the configured learning-rate schedule.
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # Plain supervised training (no distillation).
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForQASimpleAdaptorTraining)

        # Stack the per-feature fields into tensors; the TensorDataset field
        # order below is what the adaptor receives.
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_doc_mask = torch.tensor([f.doc_mask for f in train_features],
                                    dtype=torch.float)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)

        train_dataset = TensorDataset(all_input_ids, all_segment_ids,
                                      all_input_mask, all_doc_mask,
                                      all_start_positions, all_end_positions)
        # local_rank == -1 is guaranteed by the check above.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        callback_func = partial(predict,
                                eval_examples=eval_examples,
                                eval_features=eval_features,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
# Example #8
# 0
def main():
    """Train (or only evaluate) a single student model with plain supervised training.

    No teacher is involved: ``BasicTrainer`` trains ``model_S`` directly, so
    this script can be used to fine-tune a model that later acts as a teacher.

    Steps: parse ``config.args``, seed the RNGs, build the student config and
    model for ``args.model_architecture`` ("electra", "albert", or BERT
    otherwise), optionally initialise its weights according to
    ``args.load_model_type``, then either train (``do_train``) with ``predict``
    as the training callback, or run ``predict`` once and print the result
    when only ``do_predict`` is set.

    Raises:
        NotImplementedError: distributed training (``args.local_rank != -1``)
            is not supported.
    """
    # Parse arguments and log every resolved value.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # Seed every RNG in use.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Argument check; yields the device and GPU count used below.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of one forward pass: the configured batch is split across
    # the gradient-accumulation steps.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Prepare the task.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. for MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # ELECTRA and ALBERT share the same config/model-loading path; only the
    # concrete classes differ.
    is_hf_arch = args.model_architecture in ("electra", "albert")

    # Student configuration.
    if is_hf_arch:
        config_cls = ElectraConfig if args.model_architecture == "electra" else AlbertConfig
        bert_config_S = config_cls.from_json_file(args.bert_config_file_S)
        # The flags are strings; only the literal 'true' enables the output.
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # Number of classes.
        bert_config_S.num_labels = num_labels
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Read data.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    # Every architecture here uses the BERT tokenizer.
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and derive the number of training steps.
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
        logger.info("训练数据集已加载")
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both the matched and mismatched sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("预测数据集已加载")

    # Build the student; it is the only model trained by this script.
    if args.model_architecture == "electra":
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)

    # Initialise the student's weights.
    if args.load_model_type == 'bert' and not is_hf_arch:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keep only 'bert.embeddings*' keys (embedding-only load) or every
        # 'bert.*' key; k[5:] strips the 'bert.' prefix so the keys match
        # model_S.bert's own state dict.
        prefix = 'bert.embeddings' if args.only_load_embedding else 'bert.'
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith(prefix)}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        if args.only_load_embedding:
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            # Logged (not printed) so both branches report through the logger.
            logger.info(f"missing_keys,注意丢失的参数{missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif is_hf_arch:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")

    # Move to the device. Distributed training is unsupported; several local
    # GPUs are used through DataParallel.
    model_S.to(device)
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)

    if args.do_train:
        # Split all named parameters into optimizer groups (decay / no-decay).
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("要训练的模型参数量组是,包括decay_group和no_decay_group: %d", len(all_trainable_params))
        # Optimizer.
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** 开始训练 *****")
        logger.info("  样本数是 = %d", len(train_dataset))
        logger.info("  前向 batch size = %d", forward_batch_size)
        logger.info("  训练的steps = %d", num_train_steps)

        ########### Training configuration ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # BasicTrainer runs plain supervised training rather than distillation,
        # so it can train model_S into a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)

        # local_rank != -1 has already raised above, so a random sampler is
        # always the right choice here.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        # predict() on the eval datasets is passed as the training callback.
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # Start training.
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
Exemple #9
0
def main():
    """Distil a student BERT from one or more fine-tuned teacher BERTs on a GLUE task.

    Parses ``config.args``, seeds the RNGs, loads the teacher and student BERT
    configs, reads the train (plus optional auxiliary) and eval datasets,
    builds the student and one teacher per entry of
    ``args.tuned_checkpoint_Ts``, and then either runs
    ``MultiTeacherDistiller`` (``do_train``) with ``predict`` as the training
    callback, or, when only ``do_predict`` is set, runs ``predict`` once and
    prints the result.

    Raises:
        ValueError: ``do_train`` is set but ``tuned_checkpoint_Ts`` is empty;
            distillation needs at least one teacher.
        NotImplementedError: distributed training (``args.local_rank != -1``)
            is not supported.
    """
    # Parse arguments and log every resolved value.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Arguments check.
    device, n_gpu = args_check(args)
    # Without this check, training with no teachers only failed much later
    # (after all data was loaded) with a NameError on the undefined model_Ts.
    if args.do_train and not args.tuned_checkpoint_Ts:
        raise ValueError("do_train requires at least one teacher checkpoint in "
                         "tuned_checkpoint_Ts")
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of one forward pass: the configured batch is split across
    # the gradient-accumulation steps.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load BERT configs for teacher and student.
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare the GLUE task.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Read data.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both the matched and mismatched sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")

    # Build the student model.
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)

    # Load the teachers (one per checkpoint, put in eval mode). Always bound so
    # later code never hits an undefined name.
    model_Ts = []
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for _ in args.tuned_checkpoint_Ts]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s", ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        # Without teachers the script can only predict.
        assert args.do_predict is True

    # Load the student weights.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keep only 'bert.*' keys; k[5:] strips the 'bert.' prefix so they
        # match model_S.bert's own state dict.
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")

    # Teachers are only needed on the device when distilling.
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)

    # Distributed training is unsupported; several local GPUs use DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        if args.do_train:
            model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
        model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)

    if args.do_train:
        # Optimizer parameter groups.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        distill_config = DistillationConfig(
            temperature=args.temperature,
            kd_loss_type='ce')

        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        # Same adaptor for teachers and student.
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)

        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor,
                                          adaptor_S=adaptor)
        # local_rank != -1 has already raised above.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)