def __init__(self, path):
    # The __init__ usually loads all of the data up front.
    super(SelfDataset, self).__init__()
    # Read the raw corpus.
    self.sents_src, self.sents_tgt = read_corpus(path)
    self.word2idx = load_chinese_base_vocab()
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
    self.tokenizer = Tokenizer(self.word2idx)
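
# A Dataset also needs __len__ and __getitem__ to work with DataLoader.
# A minimal sketch of what they could look like here; the exact signature
# and return format of Tokenizer.encode are assumptions, not confirmed by
# this excerpt:
def __len__(self):
    return len(self.sents_src)

def __getitem__(self, i):
    # Encode one (source, target) pair; assumed to yield token ids plus
    # segment ids in BERT's UniLM-style seq2seq layout.
    token_ids, token_type_ids = self.tokenizer.encode(self.sents_src[i],
                                                      self.sents_tgt[i])
    return {"token_ids": token_ids, "token_type_ids": token_type_ids}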
def __init__(self, model_path, is_cuda):
    self.word2idx = load_chinese_base_vocab()
    self.config = BertConfig(len(self.word2idx))
    self.bert_seq2seq = Seq2SeqModel(self.config)
    self.is_cuda = is_cuda
    # Load the state dict onto the appropriate device.
    if is_cuda:
        device = torch.device("cuda")
        self.bert_seq2seq.load_state_dict(torch.load(model_path))
        self.bert_seq2seq.to(device)
    else:
        checkpoint = torch.load(model_path,
                                map_location=torch.device("cpu"))
        self.bert_seq2seq.load_state_dict(checkpoint)
    # Inference only: disable dropout and other training-time behavior.
    self.bert_seq2seq.eval()
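
# A hedged usage sketch for this predictor class. The `generate` method and
# its arguments are assumptions about Seq2SeqModel's API, and `Predictor` is
# a hypothetical name for the class above:
#
#     predictor = Predictor(model_path='model/best.bin', is_cuda=False)
#     with torch.no_grad():
#         answer = predictor.bert_seq2seq.generate(question, beam_size=3)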
def __init__(self):
    self.pretrain_model_path = os.path.join(Config.pretrained_path,
                                            'pytorch_model.bin')
    self.batch_size = Config.batch_size
    self.lr = Config.lr
    logging.info('loading vocabulary')
    self.word2idx = load_chinese_base_vocab()
    self.tokenizer = Tokenizer(self.word2idx)
    # Check whether a GPU is available.
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")

    logging.info('using device: {}'.format(self.device))
    # Define the model hyperparameters.
    bertconfig = BertConfig(vocab_size=len(self.word2idx))
    logging.info('initializing BERT model')
    self.bert_model = Seq2SeqModel(config=bertconfig)
    logging.info('loading the pretrained model')
    self.load_model(self.bert_model, self.pretrain_model_path)
    logging.info('moving the model to the compute device (GPU or CPU)')
    self.bert_model.to(self.device)
    logging.info('declaring the parameters to optimize')
    self.optim_parameters = list(self.bert_model.parameters())
    self.init_optimizer(lr=self.lr)

    # Declare the custom data loaders.
    logging.info('loading training data')
    train = SelfDataset(
        os.path.join(Config.root_path, 'data/generative/train.tsv'))
    logging.info('loading dev data')
    dev = SelfDataset(
        os.path.join(Config.root_path, 'data/generative/dev.tsv'))
    self.trainloader = DataLoader(dataset=train,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)
    self.devloader = DataLoader(dataset=dev,
                                batch_size=self.batch_size,
                                shuffle=True,
                                collate_fn=collate_fn)
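
# collate_fn is referenced above but not defined in this excerpt. A minimal
# sketch, assuming each dataset item is a dict of variable-length
# "token_ids" and "token_type_ids" lists (the key names are assumptions)
# and that id 0 is the [PAD] token:
def collate_fn(batch):
    max_len = max(len(item["token_ids"]) for item in batch)

    def pad(seq):
        return seq + [0] * (max_len - len(seq))

    token_ids = torch.tensor([pad(item["token_ids"]) for item in batch])
    token_type_ids = torch.tensor(
        [pad(item["token_type_ids"]) for item in batch])
    return token_ids, token_type_ids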
def main():

    config_distil.parse()
    global args
    args = config_distil.args
    global logger
    logger = create_logger(args.log_file)

    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    logger.info('loading vocabulary')
    word2idx = load_chinese_base_vocab()
    # Check whether a GPU is available (and whether CUDA was requested).
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and args.is_cuda else "cpu")

    logger.info('using device: {}'.format(args.device))
    # Define the model hyperparameters: the student shares the teacher's
    # vocabulary but keeps only 3 hidden layers.
    bertconfig_T = BertConfig(vocab_size=len(word2idx))
    bertconfig_S = BertConfig(vocab_size=len(word2idx), num_hidden_layers=3)
    logger.info('initializing BERT models')
    bert_model_T = Seq2SeqModel(config=bertconfig_T)
    bert_model_S = Seq2SeqModel(config=bertconfig_S)
    logger.info('loading the teacher model')
    load_model(bert_model_T, args.tuned_checkpoint_T)
    logger.info('moving the model to the compute device (GPU or CPU)')
    bert_model_T.to(args.device)
    bert_model_T.eval()

    logger.info('loading the student model')
    if args.load_model_type == 'bert':
        load_model(bert_model_S, args.init_checkpoint_S)
    else:
        logger.info('Student model is randomly initialized.')
    logger.info('moving the model to the compute device (GPU or CPU)')
    bert_model_S.to(args.device)
    # Declare the custom data loader.
    logger.info('loading training data')
    train = SelfDataset(args.train_path, args.max_length)
    trainloader = DataLoader(train,
                             batch_size=args.train_batch_size,
                             shuffle=True,
                             collate_fn=collate_fn)

    if args.do_train:

        logger.info('declaring the parameters to optimize')
        # len(trainloader) is already the number of batches per epoch, so
        # dividing by the batch size again would undercount the total steps.
        num_train_steps = len(trainloader) * args.num_train_epochs
        optim_parameters = list(bert_model_S.named_parameters())
        all_trainable_params = divide_parameters(optim_parameters,
                                                 lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)
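        # warmup_proportion above is the fraction of t_total during which
        # the learning rate ramps up linearly before the chosen decay
        # schedule takes over (the convention of pytorch-pretrained-bert's
        # BERTAdam; assumed to apply to this BERTAdam as well).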

        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        from generative.matches import matches
        intermediate_matches = None

        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]

        logger.info(f"intermediate matches: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
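        # For reference, each entry contributed by `matches[...]` follows
        # TextBrewer's intermediate-match format, roughly (illustrative
        # values, not taken from this project's matches module):
        #   {'layer_T': 3, 'layer_S': 1, 'feature': 'hidden',
        #    'loss': 'hidden_mse', 'weight': 1}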

        def BertForS2SSimpleAdaptor(batch, model_outputs):
            # Map the model's raw outputs onto the feature names that
            # TextBrewer expects when computing distillation losses.
            return {'hidden': model_outputs[0],
                    'logits': model_outputs[1],
                    'loss': model_outputs[2],
                    'attention': model_outputs[3]}

        # partial() with no bound arguments is a no-op, so the adaptor can
        # be used directly for both teacher and student.
        adaptor_T = BertForS2SSimpleAdaptor
        adaptor_S = BertForS2SSimpleAdaptor

        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=bert_model_T,
                                     model_S=bert_model_S,
                                     adaptor_T=adaptor_T,
                                     adaptor_S=adaptor_S)
        # Evaluate on the dev set from the distiller's checkpoint callback.
        callback_func = partial(predict, data_path=args.dev_path, args=args)
        logger.info('Start distillation.')
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=trainloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(bert_model_S, args.test_path, step=0, args=args)
        print(res)
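
# Standard script entry point (assumed; not shown in this excerpt):
if __name__ == '__main__':
    main()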