Example #1
    def _load_bert(self, bert_config_path: str, bert_model_path: str):
        bert_config = BertConfig.from_json_file(bert_config_path)
        model = BertModel(bert_config)
        if self.cuda:
            model_states = torch.load(bert_model_path)
        else:
            model_states = torch.load(bert_model_path, map_location='cpu')
        # Normalize checkpoint keys: strip the "bert." prefix, drop the
        # pretraining heads ("cls*"), and rename TF-style LayerNorm
        # parameters (gamma/beta -> weight/bias). The rename is applied to a
        # copy of the key so one key can go through several rewrites without
        # being popped twice (popping per-step raises KeyError for keys that
        # match both the prefix and the gamma/beta suffix).
        for k in list(model_states.keys()):
            new_k = k
            if new_k.startswith("bert."):
                new_k = new_k[5:]
            elif new_k.startswith("cls"):
                model_states.pop(k)
                continue
            if new_k.endswith("beta"):
                new_k = new_k[:-4] + "bias"
            elif new_k.endswith("gamma"):
                new_k = new_k[:-5] + "weight"
            if new_k != k:
                model_states[new_k] = model_states.pop(k)

        model.load_state_dict(model_states)
        if self.cuda:
            model.cuda()
        model.eval()
        return model
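
The key normalization above is easy to get wrong: a key such as "bert.encoder.LayerNorm.beta" has to survive two renames (prefix strip plus beta -> bias). A minimal, self-contained sketch of the same normalization on a plain dict, with illustrative key names, for quick verification:

def remap_keys(states):
    # Strip the "bert." prefix, drop pretraining heads ("cls*"), and rename
    # TF-style LayerNorm parameters (gamma/beta -> weight/bias).
    for k in list(states.keys()):
        new_k = k
        if new_k.startswith("bert."):
            new_k = new_k[5:]
        elif new_k.startswith("cls"):
            states.pop(k)
            continue
        if new_k.endswith("beta"):
            new_k = new_k[:-4] + "bias"
        elif new_k.endswith("gamma"):
            new_k = new_k[:-5] + "weight"
        if new_k != k:
            states[new_k] = states.pop(k)
    return states

print(remap_keys({"bert.LayerNorm.gamma": 1,
                  "bert.LayerNorm.beta": 2,
                  "cls.predictions.bias": 3}))
# -> {'LayerNorm.weight': 1, 'LayerNorm.bias': 2}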
Example #2
def main():
    parser = argparse.ArgumentParser()

    # 1. Training and test data paths
    parser.add_argument("--data_dir",
                        default='./data/cluener',
                        type=str,
                        help="Path to data.")
    parser.add_argument("--type_description",
                        default='./data/cluener/type_des.json',
                        type=str,
                        help="Path to data.")

    # 2. Pretrained model paths
    parser.add_argument("--vocab_file",
                        default="./data/pretrain/vocab.txt",
                        type=str,
                        help="Init vocab to resume training from.")
    parser.add_argument("--config_path",
                        default="./data/pretrain/config.json",
                        type=str,
                        help="Init config to resume training from.")
    parser.add_argument("--init_checkpoint",
                        default="./data/pretrain/pytorch_model.bin",
                        type=str,
                        help="Init checkpoint to resume training from.")

    # 3. Model saving and loading
    parser.add_argument("--save_path",
                        default="./check_points/",
                        type=str,
                        help="Path to save checkpoints.")
    parser.add_argument("--load_path",
                        default=None,
                        type=str,
                        help="Path to load checkpoints.")

    # Training and evaluation parameters
    parser.add_argument("--do_train",
                        default=True,
                        type=bool,
                        help="Whether to perform training.")
    parser.add_argument("--do_eval",
                        default=True,
                        type=bool,
                        help="Whether to perform evaluation on test data set.")
    parser.add_argument("--do_predict",
                        default=False,
                        type=bool,
                        help="Whether to perform evaluation on test data set.")
    parser.add_argument("--do_adv", default=True, type=bool)

    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="Number of epoches for fine-tuning.")
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total examples' number in batch for training.")
    parser.add_argument("--eval_batch_size",
                        default=1,
                        type=int,
                        help="Total examples' number in batch for eval.")
    parser.add_argument("--max_seq_len",
                        default=300,
                        type=int,
                        help="Number of words of the longest seqence.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="Learning rate used to train with warmup.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        # a literal '%' must be escaped as '%%' in argparse help strings,
        # otherwise printing the help raises a formatting error
        help="Proportion of training to perform linear learning rate warmup "
             "for. E.g., 0.1 = 10%% of training.")

    parser.add_argument("--use_cuda",
                        type=bool,
                        default=True,
                        help="whether to use cuda")
    parser.add_argument("--log_steps",
                        type=int,
                        default=20,
                        help="The steps interval to print loss.")
    parser.add_argument("--eval_step",
                        type=int,
                        default=1000,
                        help="The steps interval to print loss.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    # fall back to CPU gracefully when no CUDA device is available
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        n_gpu = 0
    logger.info("device: {}, n_gpu: {}".format(device, n_gpu))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    model_path_postfix = ''
    if args.do_adv:
        model_path_postfix += '_adv'

    args.save_path = os.path.join(args.save_path, 'ner' + model_path_postfix)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    bert_tokenizer = util.CNerTokenizer.from_pretrained(args.vocab_file)
    bert_config = BertConfig.from_pretrained(args.config_path)

    with open(args.type_description, encoding='utf-8') as f:
        type2description = json.load(f)

    # Load datasets
    train_dataset = None
    eval_dataset = None
    if args.do_train:
        logger.info("loading train dataset")
        train_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'train.json'), bert_tokenizer,
            args.max_seq_len, type2description)

    if args.do_eval:
        logger.info("loading eval dataset")
        eval_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'dev.json'), bert_tokenizer,
            args.max_seq_len, type2description, shuffle=False)

    if args.do_predict:
        logger.info("loading test dataset")
        test_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'test.json'), bert_tokenizer,
            args.max_seq_len, type2description, shuffle=False)

    if args.do_train:
        logging.info("Start training !")
        train_helper.train(bert_tokenizer, bert_config, args, train_dataset,
                           eval_dataset)

    if not args.do_train and args.do_eval:
        logger.info("Start evaluating!")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)

        state = torch.load(args.load_path,
                           map_location=None if args.use_cuda else 'cpu')
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logger.info("Checkpoint %s has been loaded!" % args.load_path)

        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()
        model_list = [bert_model, span_model]
        train_helper.evaluate(args, eval_dataset, model_list)

    if args.do_predict:
        logger.info("Start predicting!")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)

        state = torch.load(args.load_path,
                           map_location=None if args.use_cuda else 'cpu')
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logger.info("Checkpoint %s has been loaded!" % args.load_path)

        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()

        model_list = [bert_model, span_model]
        predict_res = train_helper.predict(args, test_dataset, model_list)
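
One detail from this example is worth demonstrating in isolation: argparse hands the raw command-line string to the type callable, so type=bool turns any non-empty string, including "False", into True. A minimal, runnable demonstration of the pitfall and of the explicit parser used above:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--flag", type=bool, default=True)
print(p.parse_args(["--flag", "False"]).flag)   # True -- bool("False") is True

p2 = argparse.ArgumentParser()
p2.add_argument("--flag",
                type=lambda s: s.lower() in ('true', '1', 'yes'),
                default=True)
print(p2.parse_args(["--flag", "False"]).flag)  # False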
Example #3
    def __init__(self, dataset, tokenizer: BertTokenizerFast, model: BertModel, batch_size=32):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.model = model.cuda()  # assumes a CUDA device is available
        self.batch_size = batch_size
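
For context, a minimal sketch of how a wrapper like this might push one batch through the model (shown on CPU for portability; the class above moves the model to the GPU). The checkpoint name "bert-base-uncased" and the sentences are illustrative:

import torch
from transformers import BertTokenizerFast, BertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").eval()

batch = ["a short sentence", "a second, slightly longer sentence"]
inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state  # (batch, seq_len, hidden_size)
print(hidden.shape)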
Example #4
class NERPredict(IPredict):
    '''
    Constructor; initializes the predictor.
    use_gpu: whether to use the GPU
    bert_config_file_name: path to the BERT config file
    vocab_file_name: path to the vocabulary file
    tags_file_name: path to the tags file
    bert_model_path: path to the saved BERT model
    lstm_crf_model_path: path to the saved LSTM-CRF model
    hidden_dim: CRF hidden layer size
    '''
    def __init__(self, use_gpu, bert_config_file_name, vocab_file_name,
                 tags_file_name, bert_model_path, lstm_crf_model_path,
                 hidden_dim):
        self.use_gpu = use_gpu
        self.data_manager_init(vocab_file_name, tags_file_name)
        self.tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
        self.model_init(hidden_dim, bert_config_file_name, bert_model_path,
                        lstm_crf_model_path)

    def data_manager_init(self, vocab_file_name, tags_file_name):
        tags_list = BERTDataManager.ReadTagsList(tags_file_name)
        tags_list = [tags_list]
        self.dm = BERTDataManager(tags_list=tags_list,
                                  vocab_file_name=vocab_file_name)

    def model_init(self, hidden_dim, bert_config_file_name, bert_model_path,
                   lstm_crf_model_path):
        config = BertConfig.from_json_file(bert_config_file_name)

        self.model = BertModel(config)

        # The checkpoint was saved as a DataParallel-wrapped model, so unwrap
        # it via .module before extracting the state dict.
        bert_dict = torch.load(bert_model_path).module.state_dict()

        self.model.load_state_dict(bert_dict)
        self.birnncrf = torch.load(lstm_crf_model_path)

        self.model.eval()
        self.birnncrf.eval()

    def data_process(self, sentences):
        result = []
        pad_tag = '[PAD]'
        if isinstance(sentences, str):
            sentences = [sentences]
        max_len = 0
        for sentence in sentences:
            encode = self.tokenizer.encode(sentence, add_special_tokens=True)
            result.append(encode)
            if max_len < len(encode):
                max_len = len(encode)

        # Right-pad every encoded sentence to the length of the longest one.
        for i, sentence in enumerate(result):
            remain = max_len - len(sentence)
            result[i].extend([self.dm.wordToIdx(pad_tag)] * remain)
        return torch.tensor(result)

    def pred(self, sentences):
        sentences = self.data_process(sentences)

        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.birnncrf.cuda()
            sentences = sentences.cuda()

        outputs = self.model(input_ids=sentences,
                             attention_mask=sentences.gt(0))
        hidden_states = outputs[0]
        scores, tags = self.birnncrf(hidden_states, sentences.gt(0))
        final_tags = []
        decode_sentences = []

        for item in tags:
            final_tags.append([self.dm.idx_to_tag[tag] for tag in item])

        for item in sentences.tolist():
            decode_sentences.append(self.tokenizer.decode(item))

        return (scores, tags, final_tags, decode_sentences)

    def __call__(self, sentences):
        return self.pred(sentences)
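
A hypothetical usage sketch for this predictor; every path below is a placeholder and hidden_dim=256 is an arbitrary example value:

predictor = NERPredict(
    use_gpu=True,
    bert_config_file_name="path/to/bert_config.json",
    vocab_file_name="path/to/vocab.txt",
    tags_file_name="path/to/tags.txt",
    bert_model_path="path/to/bert_model.pt",
    lstm_crf_model_path="path/to/lstm_crf.pt",
    hidden_dim=256)
scores, tags, final_tags, decoded = predictor("这是一个测试句子")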
Example #5
    # concatenated vocabulary
    config.vocab_size = 50155
    print('bert configs:')
    print(config)

    # projection layer
    # for aligning context and response features
    projection = nn.Sequential(
        nn.Linear(config.hidden_size, config.hidden_size), nn.LeakyReLU(),
        nn.Linear(config.hidden_size, config.hidden_size), nn.LeakyReLU(),
        nn.Linear(config.hidden_size, config.hidden_size))

    # build bert from config; note that BertModel(config) is randomly
    # initialized -- pretrained weights would have to be loaded separately
    model = BertModel(config=config)

    model.cuda()
    projection.cuda()

    model = nn.DataParallel(model)
    projection = nn.DataParallel(projection)

    # for fine tuning only projection layer
    # for params in model.parameters():
    #     params.requires_grad = False
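
    # (sketch) if the freeze above is enabled, the optimizer should receive
    # only parameters that still require gradients, e.g.
    #   trainable = [p for p in projection.parameters() if p.requires_grad]
    #   optimizer = torch.optim.Adam(trainable, lr=1e-4)  # lr is illustrative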

    # get paths for model weights store/load
    if args.exp_name is not None:
        args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset,
                                            args.exp_name)
    else:
        args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset,