コード例 #1
0
def main():
    """Run multi-teacher distillation and/or evaluation on a GLUE-style task.

    All settings come from ``config.args``. With ``do_train`` every teacher
    listed in the model config is built, loaded from its checkpoint, put in
    eval mode and distilled into the student through
    ``MultiTeacherDistiller``. With ``do_predict`` but without ``do_train``
    the student is evaluated once through ``predict`` and the result printed.

    Raises:
        NotImplementedError: if ``args.local_rank`` is not ``-1``.
    """
    # ---- arguments: parse, then echo every value to the log ----
    config.parse()
    args = config.args
    for name, value in vars(args).items():
        logger.info(f"{name}:{value}")

    # ---- seed every RNG in use with the same value ----
    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # ---- argument checks and derived settings ----
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass when gradients are accumulated.
    forward_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # ---- model specification: teachers and student ----
    model_specs = parse_model_config(args.model_config_json)

    # ---- task metadata ----
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    num_labels = len(processor.get_labels())

    # ---- data: both splits are built with the student's tokenizer ----
    student = model_specs['student']
    tokenizer_S = student['tokenizer']
    prefix_S = student['prefix']

    train_dataset = None
    eval_datasets = None
    if args.do_train:
        train_dataset = load_and_cache_examples(
            args, args.task_name, tokenizer_S, prefix=prefix_S,
            evaluate=False)
    if args.do_predict:
        # MNLI is evaluated on both the "mnli" and "mnli-mm" sets.
        if args.task_name == "mnli":
            eval_task_names = ("mnli", "mnli-mm")
        else:
            eval_task_names = (args.task_name, )
        eval_datasets = [
            load_and_cache_examples(
                args, task, tokenizer_S, prefix=prefix_S, evaluate=True)
            for task in eval_task_names
        ]
    logger.info("Data loaded")

    # ---- teachers: only needed when training ----
    if args.do_train:
        model_Ts = []
        for spec in model_specs['teachers']:
            model_type_T = spec['model_type']
            config_T = spec['config']
            checkpoint_T = spec['checkpoint']

            _, _, teacher_cls = MODEL_CLASSES[model_type_T]
            teacher = teacher_cls(config_T, num_labels=num_labels)
            weights_T = torch.load(checkpoint_T, map_location='cpu')
            # strict=True: the checkpoint has to match the model exactly.
            _missing, _unexpected = teacher.load_state_dict(weights_T,
                                                            strict=True)
            logger.info(f"Teacher Model {model_type_T} loaded")
            teacher.to(device)
            teacher.eval()
            model_Ts.append(teacher)

    # ---- student ----
    model_type_S = student['model_type']
    config_S = student['config']
    checkpoint_S = student['checkpoint']
    _, _, student_cls = MODEL_CLASSES[model_type_S]
    model_S = student_cls(config_S, num_labels=num_labels)
    if checkpoint_S is None:
        logger.warning("Initializing student randomly")
    else:
        weights_S = torch.load(checkpoint_S, map_location='cpu')
        # strict=False: mismatching keys are reported instead of raising.
        missing, unexpected = model_S.load_state_dict(weights_S,
                                                      strict=False)
        logger.info(f"missing keys:{missing}")
        logger.info(f"unexpected keys:{unexpected}")
    logger.info("Student Model loaded")
    model_S.to(device)

    # ---- parallelism: distributed runs are unsupported ----
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        if args.do_train:
            model_Ts = [torch.nn.DataParallel(m) for m in model_Ts]
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Parameter groups for the optimizer.
        named_params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(named_params,
                                                 lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank != -1:
            raise NotImplementedError
        train_dataloader = DataLoader(
            train_dataset,
            sampler=RandomSampler(train_dataset),
            batch_size=args.forward_batch_size,
            drop_last=True)
        # Optimizer updates per epoch, then over the whole run.
        updates_per_epoch = (len(train_dataloader) //
                             args.gradient_accumulation_steps)
        num_train_steps = int(updates_per_epoch * args.num_train_epochs)

        # ---- distillation setup ----
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            fp16=args.fp16,
            device=args.device)
        distill_config = DistillationConfig(
            temperature=args.temperature, kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")

        # Teachers and student share the same adaptor.
        distiller = MultiTeacherDistiller(
            train_config=train_config,
            distill_config=distill_config,
            model_T=model_Ts,
            model_S=model_S,
            adaptor_T=BertForGLUESimpleAdaptor,
            adaptor_S=BertForGLUESimpleAdaptor)

        optimizer = AdamW(all_trainable_params, lr=args.learning_rate)
        warmup_steps = int(args.warmup_proportion * num_train_steps)
        scheduler_args = {
            'num_warmup_steps': warmup_steps,
            'num_training_steps': num_train_steps,
        }

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        # `predict` is handed to the distiller as its callback.
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args)
        with distiller:
            distiller.train(
                optimizer,
                scheduler_class=get_linear_schedule_with_warmup,
                scheduler_args=scheduler_args,
                dataloader=train_dataloader,
                num_epochs=args.num_train_epochs,
                callback=callback_func,
                max_grad_norm=1)

    # ---- evaluation-only run ----
    if args.do_predict and not args.do_train:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
コード例 #2
0
def main():
    """Distil a teacher BERT into a student BERT on a GLUE-style task.

    All settings are read from ``config.args``. The function seeds the RNGs,
    loads the BERT configs and datasets, builds and initialises the teacher
    and student, then either trains the student with a ``GeneralDistiller``
    (``do_train``) or evaluates it once with ``predict`` (``do_predict``
    without ``do_train``).

    Raises:
        NotImplementedError: if ``args.local_rank`` is not ``-1``.
        AssertionError: if a sequence-length or checkpoint precondition
            checked below does not hold.
    """
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of one forward pass when gradients are accumulated.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    #load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    # Inputs must not be longer than either model's position embeddings.
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    # Load the datasets.
    # NOTE(review): `examples` is rebound by every load_and_cache_examples
    # call below, so the value later passed to `predict` is whichever set
    # was loaded last (train, aux, or the last eval task) — confirm this is
    # what `predict` expects.
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args,
                                                          args.task_name,
                                                          tokenizer,
                                                          evaluate=False)
        if args.aux_task_name:
            # Optional auxiliary task: its training set is concatenated to
            # the main one.
            aux_train_dataset, examples = load_and_cache_examples(
                args,
                args.aux_task_name,
                tokenizer,
                evaluate=False,
                is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on both "mnli" and "mnli-mm".
        eval_task_names = ("mnli",
                           "mnli-mm") if args.task_name == "mnli" else (
                               args.task_name, )
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args,
                                                             eval_task,
                                                             tokenizer,
                                                             evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("数据集加载成功")

    # Build the models: teacher and student.
    model_T = BertForGLUESimple(bert_config_T,
                                num_labels=num_labels,
                                args=args)
    model_S = BertForGLUESimple(bert_config_S,
                                num_labels=num_labels,
                                args=args)
    # Load the teacher weights.
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        # NOTE(review): this only checks do_predict; a run with both
        # do_train and do_predict passes it even though no teacher weights
        # were loaded — confirm that is intended.
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            # Keep only the embedding weights; k[5:] drops the leading
            # 'bert.' so keys match the `bert` submodule.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items()
                if k.startswith('bert.embeddings')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            # Load every 'bert.'-prefixed weight (prefix stripped) and
            # require that nothing in the `bert` submodule is left missing.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items() if k.startswith('bert.')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        # Load a full student checkpoint into the whole model.
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        # No checkpoint is loaded; the student keeps its initial weights.
        logger.info("Student模型没有可加载参数,随机初始化参数 randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    # Distributed runs are not supported; multi-GPU uses DataParallel.
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  #,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)

    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer configuration.
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # `matches` maps a name to a predefined list of intermediate-match
        # configurations; the lists selected by args.matches are concatenated.
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"中间层match信息: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        logger.info(f"训练配置: {train_config}")
        logger.info(f"蒸馏配置: {distill_config}")
        # Teacher and student use the same adaptor with the same options.
        adaptor_T = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        # General distiller that supports intermediate-state matching.
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T,
                                     model_S=model_S,
                                     adaptor_T=adaptor_T,
                                     adaptor_S=adaptor_S)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # `predict` is handed to the distiller as its callback.
        # NOTE(review): unlike the evaluation-only call at the end, no
        # label_list is passed here — verify against predict's signature.
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args,
                                examples=examples)
        with distiller:
            # scheduler=None: no separate scheduler object is passed; the
            # warmup/schedule settings were given to BERTAdam above.
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    # Evaluation-only run.
    if not args.do_train and args.do_predict:
        res = predict(model_S,
                      eval_datasets,
                      step=0,
                      args=args,
                      examples=examples,
                      label_list=label_list)
        print(res)
コード例 #3
0
def train(args, train_dataset, model_T, model, tokenizer, labels,
          pad_token_label_id, predict_callback):
    """Prepare optimisation for ``model`` and distil ``model_T`` into it.

    Builds the training dataloader, an AdamW optimizer with a linear
    warmup/decay schedule, optional apex fp16 and (Distributed)DataParallel
    wrappers, and then, only when both ``args.do_train`` and
    ``args.do_distill`` are set, runs a TextBrewer ``GeneralDistiller``.
    If either flag is unset, nothing is trained.

    Args:
        args: run configuration. ``args.train_batch_size`` is always
            overwritten, and ``args.num_train_epochs`` is overwritten when
            ``args.max_steps > 0``.
        train_dataset: dataset fed to the training ``DataLoader``.
        model_T: teacher model handed to the distiller.
        model: student model; its parameters are the ones optimised.
        tokenizer: not used in this function body.
        labels: not used in this function body.
        pad_token_label_id: not used in this function body.
        predict_callback: forwarded to ``distiller.train`` as ``callback``.

    Returns:
        None.

    Raises:
        ImportError: if ``args.fp16`` is set and apex cannot be imported.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # Total optimizer steps: fixed by max_steps if given (epochs are then
    # derived from it), otherwise derived from the epoch count.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # NOTE(review): warmup_steps is multiplied by t_total, i.e. it is used
    # as a fraction of the run rather than a step count — confirm the
    # argument is meant to be a proportion.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_steps * t_total),
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    if args.do_train and args.do_distill:
        from textbrewer import DistillationConfig, TrainingConfig, GeneralDistiller
        distill_config = DistillationConfig(temperature=8)
        # NOTE(review): the device is hard-coded to "cuda" regardless of
        # where the models actually live — confirm for CPU / multi-device runs.
        train_config = TrainingConfig(device="cuda",
                                      log_dir=args.output_dir,
                                      output_dir=args.output_dir)

        def logits_adaptor(batch, model_output):
            """Map a batch and model output to the dict the distiller reads.

            Assumes ``model_output[1]`` holds the logits and that ``batch``
            is a mapping with an ``'attention_mask'`` entry — TODO confirm
            against the model's output layout.
            """
            return {
                "logits": (model_output[1], ),
                'logits_mask': (batch['attention_mask'], )
            }

        # Teacher and student outputs are adapted the same way, so one
        # adaptor serves both roles.
        distiller = GeneralDistiller(
            train_config,
            distill_config,
            model_T,
            model,
            logits_adaptor,
            logits_adaptor,
        )
        # Run training inside the distiller context, as TextBrewer's
        # documented usage does.
        with distiller:
            distiller.train(optimizer,
                            scheduler,
                            train_dataloader,
                            args.num_train_epochs,
                            callback=predict_callback)
コード例 #4
0
ファイル: main.distill.py プロジェクト: vincentlux/TextBrewer
def main():
    """Distil a teacher BERT QA model into a student, and/or evaluate it.

    All settings come from ``config.args``. Training data is read from
    ``args.train_file`` plus the optional ``args.fake_file_1`` and
    ``args.fake_file_2``. With ``do_train`` the student is trained through a
    ``GeneralDistiller``; with ``do_predict`` but without ``do_train`` the
    student is evaluated once through ``predict`` and the result printed.

    Raises:
        NotImplementedError: if ``args.local_rank`` is not ``-1``.
        AssertionError: if a sequence-length or checkpoint precondition
            checked below does not hold.
    """
    # ---- arguments: parse, then echo every value to the log ----
    config.parse()
    args = config.args
    for name, value in vars(args).items():
        logger.info(f"{name}:{value}")

    # ---- seed every RNG in use with the same value ----
    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # ---- argument checks and derived settings ----
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of a single forward pass when gradients are accumulated.
    forward_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # ---- BERT configs; inputs must fit both position-embedding tables ----
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # ---- data ----
    train_examples = train_features = None
    eval_examples = eval_features = None
    num_train_steps = None

    tokenizer = ChineseFullTokenizer(
        vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(
        convert_examples_to_features,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length)

    def _read(path, is_training):
        # Every file is read and converted with the same settings.
        return read_and_convert(
            path,
            is_training=is_training,
            do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples,
            convert_fn=convert_fn)

    if args.do_train:
        train_examples, train_features = _read(args.train_file, True)
        # Optional extra training files are appended to the main set.
        for extra_file in (args.fake_file_1, args.fake_file_2):
            if extra_file:
                extra_examples, extra_features = _read(extra_file, True)
                train_examples += extra_examples
                train_features += extra_features

        batches_per_epoch = int(len(train_features) / args.train_batch_size)
        num_train_steps = batches_per_epoch * args.num_train_epochs

    if args.do_predict:
        eval_examples, eval_features = _read(args.predict_file, False)

    # ---- models ----
    model_T = BertForQASimple(bert_config_T, args)
    model_S = BertForQASimple(bert_config_S, args)

    # Teacher weights; without a checkpoint the run must at least predict.
    if args.tuned_checkpoint_T is None:
        assert args.do_predict is True
    else:
        weights_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(weights_T)
        model_T.eval()

    # Student weights: backbone only ('bert'), full model ('all'), or none.
    load_type = args.load_model_type
    if load_type == 'bert':
        assert args.init_checkpoint_S is not None
        weights_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keys are passed through unchanged into the `bert` submodule, and
        # nothing in it may be left uninitialised.
        backbone_weights = dict(weights_S.items())
        missing_keys, _ = model_S.bert.load_state_dict(backbone_weights,
                                                       strict=False)
        assert len(missing_keys) == 0
    elif load_type == 'all':
        assert args.tuned_checkpoint_S is not None
        weights_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(weights_S)
    else:
        logger.info("Model is randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    # ---- parallelism: distributed runs are unsupported ----
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_T = torch.nn.DataParallel(model_T)
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Parameter groups and optimizer.
        named_params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(named_params,
                                                 lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        optimizer = BERTAdam(
            all_trainable_params,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            t_total=num_train_steps,
            schedule=args.schedule,
            s_opt1=args.s_opt1,
            s_opt2=args.s_opt2,
            s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        # ---- distillation setup ----
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # Concatenate the predefined match lists selected by args.matches.
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match_name in args.matches:
                intermediate_matches += matches[match_name]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        # Teacher and student use the same adaptor with the same options.
        adaptor_T = partial(BertForQASimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForQASimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)

        distiller = GeneralDistiller(
            train_config=train_config,
            distill_config=distill_config,
            model_T=model_T,
            model_S=model_S,
            adaptor_T=adaptor_T,
            adaptor_S=adaptor_S)

        def _column(attr, dtype):
            # Stack one attribute of every training feature into a tensor.
            return torch.tensor([getattr(f, attr) for f in train_features],
                                dtype=dtype)

        train_dataset = TensorDataset(
            _column('input_ids', torch.long),
            _column('segment_ids', torch.long),
            _column('input_mask', torch.long),
            _column('doc_mask', torch.float),
            _column('start_position', torch.long),
            _column('end_position', torch.long))

        if args.local_rank != -1:
            raise NotImplementedError
        train_dataloader = DataLoader(
            train_dataset,
            sampler=RandomSampler(train_dataset),
            batch_size=args.forward_batch_size,
            drop_last=True)

        # `predict` is handed to the distiller as its callback.
        callback_func = partial(predict,
                                eval_examples=eval_examples,
                                eval_features=eval_features,
                                args=args)
        with distiller:
            distiller.train(
                optimizer,
                scheduler=None,
                dataloader=train_dataloader,
                num_epochs=args.num_train_epochs,
                callback=callback_func)

    # ---- evaluation-only run ----
    if args.do_predict and not args.do_train:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
コード例 #5
0
ファイル: distill.py プロジェクト: zmskye/TextBrewer
# Training runs on `device`; all other training options keep their defaults.
train_config = TrainingConfig(device=device)

# Distillation settings. The intermediate matches pair teacher layers 0 and 8
# with student layers 0 and 2: first an MSE loss on the hidden states
# ('hidden_mse'), then an 'nst' loss whose layer entries are [layer, layer]
# pairs. Every match has weight 1.
distill_config = DistillationConfig(
    temperature=8,
    hard_label_weight=0,
    kd_loss_type='ce',
    probability_shift=False,
    intermediate_matches=[
        {
            'layer_T': t_layer,
            'layer_S': s_layer,
            'feature': 'hidden',
            'loss': 'hidden_mse',
            'weight': 1,
        } for t_layer, s_layer in ((0, 0), (8, 2))
    ] + [
        {
            'layer_T': [t_layer, t_layer],
            'layer_S': [s_layer, s_layer],
            'feature': 'hidden',
            'loss': 'nst',
            'weight': 1,
        } for t_layer, s_layer in ((0, 0), (8, 2))
    ])
コード例 #6
0
ファイル: snippet.py プロジェクト: zuiwufenghua/TextBrewer
import torch
from textbrewer import DistillationConfig, GeneralDistiller, TrainingConfig

# Placeholders: the initialization of the models, optimizer, scheduler,
# dataloader and epoch count is omitted; replace each `...` with a real object.
teacher_model : torch.nn.Module = ...
student_model : torch.nn.Module = ...
dataloader : torch.utils.data.DataLoader = ...
optimizer : torch.optim.Optimizer = ...
scheduler : torch.optim.lr_scheduler = ...
# Number of training epochs; it is passed to distiller.train below.
num_epochs : int = ...

def simple_adaptor(batch, model_outputs):
    """Expose the model's logits under the ``'logits'`` key.

    Assumes the first element of ``model_outputs`` is the logits before
    softmax. ``batch`` is accepted but not used.
    """
    logits = model_outputs[0]
    return dict(logits=logits)

# Default training and distillation settings.
train_config = TrainingConfig()
distill_config = DistillationConfig()

# Teacher and student outputs are both read through simple_adaptor.
distiller = GeneralDistiller(
    train_config=train_config,
    distill_config=distill_config,
    model_T=teacher_model,
    model_S=student_model,
    adaptor_T=simple_adaptor,
    adaptor_S=simple_adaptor)

# Run training inside the distiller context, following TextBrewer's
# documented usage.
# NOTE(review): the context manager is expected to handle train/eval mode
# switching for the two models — confirm against the installed version.
with distiller:
    distiller.train(optimizer, scheduler,
                    dataloader, num_epochs, callback=None)





def predict(model, eval_dataset, step, args):
    """Evaluation stub: always raises ``NotImplementedError``.

    Replace the body with a real evaluation routine and fill in any other
    arguments it needs.
    """
    raise NotImplementedError()
コード例 #7
0
def main():
    """Distil a BERT seq2seq teacher into a 3-layer student with TextBrewer.

    All settings are read from ``config_distil.args``; ``args`` and ``logger``
    are rebound as module globals. When ``args.do_train`` is set, a
    ``GeneralDistiller`` trains the student against the teacher (which is put
    in eval mode). When only ``args.do_predict`` is set, ``predict`` is run on
    ``args.test_path`` with the student and its result is printed.
    """
    config_distil.parse()
    global args
    args = config_distil.args
    global logger
    logger = create_logger(args.log_file)

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    logger.info('加载字典')
    word2idx = load_chinese_base_vocab()
    # Use the GPU only when one is available and args.is_cuda is set
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and args.is_cuda else "cpu")

    logger.info('using device:{}'.format(args.device))
    # Model hyper-parameters: the student differs from the teacher only in
    # using 3 hidden layers
    bertconfig_T = BertConfig(vocab_size=len(word2idx))
    bertconfig_S = BertConfig(vocab_size=len(word2idx), num_hidden_layers=3)
    logger.info('初始化BERT模型')
    bert_model_T = Seq2SeqModel(config=bertconfig_T)
    bert_model_S = Seq2SeqModel(config=bertconfig_S)
    logger.info('加载Teacher模型~')
    load_model(bert_model_T, args.tuned_checkpoint_T)
    logger.info('将模型发送到计算设备(GPU或CPU)')
    bert_model_T.to(args.device)
    bert_model_T.eval()

    logger.info('加载Student模型~')
    if args.load_model_type == 'bert':
        load_model(bert_model_S, args.init_checkpoint_S)
    else:
        logger.info(" Student Model is randomly initialized.")
    logger.info('将模型发送到计算设备(GPU或CPU)')
    bert_model_S.to(args.device)

    # Build the custom dataset and its data loader
    logger.info('加载训练数据')
    train = SelfDataset(args.train_path, args.max_length)
    trainloader = DataLoader(train,
                             batch_size=args.train_batch_size,
                             shuffle=True,
                             collate_fn=collate_fn)

    if args.do_train:

        logger.info(' 声明需要优化的参数')
        # len(trainloader) is already the number of batches per epoch, so it
        # must not be divided by the batch size again. One optimizer update
        # happens every `gradient_accumulation_steps` batches.
        num_train_steps = (
            len(trainloader) //
            args.gradient_accumulation_steps) * args.num_train_epochs
        optim_parameters = list(bert_model_S.named_parameters())
        all_trainable_params = divide_parameters(optim_parameters,
                                                 lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        from generative.matches import matches

        # Concatenate the match specs selected by name in args.matches.
        # When args.matches is not a list/tuple (e.g. None), no intermediate
        # matches are configured.
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]

        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        def BertForS2SSimpleAdaptor(batch, model_outputs):
            # Map the model's positional outputs onto the names used by the
            # distiller; `batch` is not used.
            return {
                'hidden': model_outputs[0],
                'logits': model_outputs[1],
                'loss': model_outputs[2],
                'attention': model_outputs[3],
            }

        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=bert_model_T,
                                     model_S=bert_model_S,
                                     adaptor_T=BertForS2SSimpleAdaptor,
                                     adaptor_S=BertForS2SSimpleAdaptor)
        logger.info('Start distillation.')
        # NOTE(review): no callback is passed to train(), as in the original
        # code, which built a `predict` partial but never used it. Pass one
        # here if evaluation during training is wanted.
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=trainloader,
                            num_epochs=args.num_train_epochs,
                            callback=None)

    if not args.do_train and args.do_predict:
        res = predict(bert_model_S, args.test_path, step=0, args=args)
        print(res)
コード例 #8
0
def main():
    """Distil one or more fine-tuned BERT teachers into a student on a GLUE task.

    Settings come from ``config.args``. With ``args.do_train`` the student is
    trained by a TextBrewer ``MultiTeacherDistiller`` using every checkpoint in
    ``args.tuned_checkpoint_Ts`` as a teacher. With ``args.do_predict`` alone,
    ``predict`` is run on the evaluation datasets with the student and its
    result is printed.

    Raises:
        ValueError: if ``args.do_train`` is set but no teacher checkpoint is
            given (the original code failed later with a ``NameError``).
        NotImplementedError: if ``args.local_rank`` is not ``-1`` and either
            ``args.do_train`` is set or an evaluation dataset is being built.
    """
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; train_batch_size is the size per update.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    #load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            # Append the auxiliary task's training data to the main task's.
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        # MNLI is evaluated on both its matched and mismatched sets.
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        eval_datasets = [
            load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
            for eval_task in eval_task_names
        ]
    logger.info("Data loaded")

    #Build Model and load checkpoint
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)

    #Load teacher
    model_Ts = []
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for _ in args.tuned_checkpoint_Ts]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s" % ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    elif args.do_train:
        # Fail early with a clear message instead of hitting an unbound
        # `model_Ts` further down.
        raise ValueError(
            "args.tuned_checkpoint_Ts must contain at least one teacher "
            "checkpoint when args.do_train is set")
    else:
        assert args.do_predict is True

    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keep only the encoder weights and strip their 'bert.' prefix so the
        # keys match model_S.bert.
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")

    # Teachers are only needed on the device when training.
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)

    # Distributed training is not supported; multi-GPU uses DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        if args.do_train:
            model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
        model_S = torch.nn.DataParallel(model_S) #,output_device=n_gpu-1)

    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps, schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        distill_config = DistillationConfig(
            temperature=args.temperature,
            kd_loss_type='ce')

        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        # The same adaptor serves the teachers and the student.
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)

        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor,
                                          adaptor_S=adaptor)
        # local_rank is always -1 here (any other value raised above), so a
        # plain random sampler is the only case.
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
コード例 #9
0
def main():
    """Distil an ELECTRA token-classification teacher into a student model.

    Settings come from ``config.args``. With ``args.do_train`` the student is
    trained by a TextBrewer ``GeneralDistiller`` and ``ddp_predict`` is passed
    as the training callback. With ``args.do_predict`` alone, ``ddp_predict``
    is run once with the student and its result is printed.

    Raises:
        NotImplementedError: if ``args.official_schedule`` is neither
            ``'const'`` nor ``'linear'`` while training.
    """
    #parse arguments
    config.parse()
    args = config.args

    # Only the single-process run (-1) or rank 0 logs at INFO level; other
    # ranks log at WARN.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -  %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    #arguments check
    device, n_gpu = args_check(logger, args)
    # Only the single-process run or rank 0 creates the output directory.
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}"
        )

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Per-forward-pass batch size, used for the DataLoader below.
    forward_batch_size = int(args.train_batch_size /
                             args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    #load bert config
    bert_config_T = ElectraConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    # args.output_encoded_layers is compared as the string 'true'.
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_T.output_hidden_states = (args.output_encoded_layers == 'true')
    # label2id_dict is not defined in this function (defined elsewhere).
    bert_config_S.num_labels = len(label2id_dict)
    bert_config_T.num_labels = len(label2id_dict)

    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    #read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None

    # NOTE(review): `tokenizer` is not used again in this function.
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)

    # Ranks other than 0 wait here until rank 0 has finished reading the
    # features and reaches its own barrier below (presumably so rank 0 can
    # populate a cache first -- confirm against read_features).
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(
            args.train_file, max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(
            args.predict_file, max_seq_length=args.max_seq_length)

    # Rank 0 has read its data; this barrier call releases the waiting ranks.
    if args.local_rank == 0:
        torch.distributed.barrier()
    #Build Model and load checkpoint
    model_T = ElectraForTokenClassification(bert_config_T)
    model_S = ElectraForTokenClassification(bert_config_S)
    #Load teacher
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Non-strict load: mismatched keys are logged rather than rejected.
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S,
                                                                strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    if args.do_train:
        #parameters
        if args.lr_decay is not None:
            # Layer-wise learning-rate decay: the classifier uses the base
            # learning rate; each encoder layer, walking from the last layer
            # to the first, gets one more factor of lr_decay; the embeddings
            # get the most.
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params,
                                               lr=args.learning_rate)

            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            # This branch only supports a 12-layer student encoder.
            assert n_layers == 12
            # i counts from the top: i=0 is the last encoder layer (n=n_layers-1).
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(
                    model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            # Embedding parameters are selected by name.
            embed_params = [
                (name, value)
                for name, value in model_S.electra.named_parameters()
                if 'embedding' in name
            ]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            # Sanity check: the groups must hold exactly as many parameters as
            # the model; the assertion message carries both counts.
            assert sum(map(lambda x:len(x['params']), all_trainable_params))==len(list(model_S.parameters())),\
                (sum(map(lambda x:len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            # No decay: every parameter uses the base learning rate.
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params,
                                                     lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        # Single-process runs sample randomly; distributed runs use a
        # DistributedSampler.
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)

        # Optimizer updates = batches per epoch // accumulation steps * epochs.
        num_train_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params,
                          lr=args.learning_rate,
                          correct_bias=False)
        # Pick the scheduler factory and its arguments; the scheduler itself
        # is not built here -- both are handed to distiller.train below.
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps)
            }
            #scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {
                'num_warmup_steps':
                int(args.warmup_proportion * num_train_steps),
                'num_training_steps': num_train_steps
            }
            #scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps = num_train_steps)
        else:
            raise NotImplementedError

        # Logged at WARNING so that every rank reports its own numbers
        # (ranks other than -1/0 log at WARN level only).
        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank,
                       len(train_examples))
        logger.warning("local_rank %d Num split examples = %d",
                       args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d",
                       args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d",
                       args.local_rank, num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")

        from matches import matches
        # Concatenate the match specs selected by name in args.matches; stays
        # None when args.matches is not a list or tuple.
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        # Teacher and student share the same adaptor.
        adaptor_T = ElectraForTokenClassificationAdaptor
        adaptor_S = ElectraForTokenClassificationAdaptor

        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T,
                                     model_S=model_S,
                                     adaptor_T=adaptor_T,
                                     adaptor_S=adaptor_S)

        # evaluate the model in a single process in ddp_predict
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    # Prediction-only mode: evaluate the student once.
    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S,
                          eval_examples,
                          eval_dataset,
                          step=0,
                          args=args)
        print(res)