def main():
    args = setup_train_args()
    # Initialize the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    tokenizer.add_special_tokens({'additional_special_tokens':['#E-s', '#crumb-wrap', '<coupon>', '<url>', '<img>', '<S>', '<E>']})
    # Size of the tokenizer's vocabulary
    vocab_size = len(tokenizer)
    if args.train_mmi:  # train the MMI model
        preprocess_mmi_raw_data(args, tokenizer)
    else:  # train the dialogue generation model
        preprocess_raw_data(args, tokenizer)
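# Usage sketch (not from the original project; 'vocab_small.txt' is a placeholder path):
# len(tokenizer) counts the base vocabulary plus every token registered via
# add_special_tokens, which is why the snippet above can pass len(tokenizer) to the model
# as its vocabulary size (typically via model.resize_token_embeddings).
from transformers import BertTokenizer

_tok = BertTokenizer(vocab_file='vocab_small.txt')
_base = len(_tok)
_num_added = _tok.add_special_tokens({'additional_special_tokens': ['<S>', '<E>']})
assert len(_tok) == _base + _num_added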
Example #2
def main():
    args = setup_train_args()
    # Log both to file and to the console
    global logger
    logger = create_logger(args)
    # Use the GPU when requested and available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    # Seed the CPU random number generator so that results are deterministic.
    # Also seed the current GPU; with multiple GPUs, torch.cuda.manual_seed_all() should be used to seed them all.
    # When we obtain a good result we usually want it to be reproducible.
    if args.seed:
        set_random_seed(args)

    # Select which GPUs to use for training
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    # Initialize the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    tokenizer.add_special_tokens({
        'additional_special_tokens':
        ['#E-s', '#crumb-wrap', '<coupon>', '<url>', '<img>', '<S>', '<E>']
    })
    # Size of the tokenizer's vocabulary
    vocab_size = len(tokenizer)

    global pad_id
    pad_id = tokenizer.convert_tokens_to_ids(PAD)

    # Create the output directory for the dialogue model
    if not os.path.exists(args.dialogue_model_output_path):
        os.mkdir(args.dialogue_model_output_path)
    # Create the output directory for the MMI model
    if not os.path.exists(args.mmi_model_output_path):
        os.mkdir(args.mmi_model_output_path)
    # Load the GPT-2 model
    model, args.n_ctx = create_model(args, vocab_size)
    model.to(device)
    # Preprocess the raw data, converting the raw corpus into the corresponding token ids
    if args.raw and args.train_mmi:  # train the MMI model
        preprocess_mmi_raw_data(args, tokenizer, args.n_ctx)
    elif args.raw and not args.train_mmi:  # train the dialogue generation model
        preprocess_raw_data(args, tokenizer, args.n_ctx)
    # Whether to use multiple GPUs for parallel training
    multi_gpu = False
    if args.cuda and torch.cuda.device_count() > 1:
        logger.info("Let's use GPUs to train")
        model = DataParallel(
            model, device_ids=[int(i) for i in args.device.split(',')])
        multi_gpu = True
    # Log the number of model parameters
    num_parameters = 0
    parameters = model.parameters()
    for parameter in parameters:
        num_parameters += parameter.numel()
    logger.info('number of model parameters: {}'.format(num_parameters))

    # Load the data
    logger.info("loading training data")
    if args.train_mmi:  # training the MMI model
        with open(args.train_mmi_tokenized_path, "r", encoding="utf8") as f:
            data = f.read()
    else:  # training the dialogue generation model
        with open(args.train_tokenized_path, "r", encoding="utf8") as f:
            data = f.read()
    data_list = data.split("\n")
    # drop empty lines, then drop the trailing element
    data_list = [line for line in data_list if line]
    data_list = data_list[:-1]
    train_list, test_list = train_test_split(data_list,
                                             test_size=0.1,
                                             random_state=1)
    # Start training
    train(model, device, train_list, multi_gpu, args)
    # Evaluate the model
    evaluate(model, device, test_list, multi_gpu, args)
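# Hedged sketch (the train()/collate code is not shown in this example): the global pad_id set
# in main() is typically used to right-pad every token-id sequence to the batch maximum so the
# batch can be stacked into one tensor; pad positions are then ignored in the loss.
# pad_batch is an illustrative helper, not a function from the original project.
def pad_batch(sequences, pad_id):
    """Right-pad a list of token-id lists to equal length and stack them into a LongTensor."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    return torch.tensor(padded, dtype=torch.long)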
class REPreprocessor(Component):
    def __init__(self,
                 vocab_file: str,
                 special_token: str = '<ENT>',
                 ner_tags=None,
                 max_seq_length: int = 512,
                 do_lower_case: bool = False,
                 default_tag: str = None,
                 **kwargs):
        """
        Args:
            vocab_file: path to vocabulary / name of vocabulary for tokenizer initialization
            special_token: an additional token that will be used for marking the entities in the document
            do_lower_case: set True if lowercasing is needed
            default_tag: used for test purposes to create a valid input
        Return:
            list of feature batches with input_ids, attention_mask, entity_pos, ner_tags
        """

        self.special_token = special_token
        self.special_tokens_dict = {
            'additional_special_tokens': [self.special_token]
        }
        self.default_tag = default_tag

        if ner_tags is None:
            ner_tags = ['ORG', 'TIME', 'MISC', 'LOC', 'PER', 'NUM']
        self.ner2id = {tag: tag_id for tag_id, tag in enumerate(ner_tags)}
        self.max_seq_length = max_seq_length

        if Path(vocab_file).is_file():
            vocab_file = str(expand_path(vocab_file))
            self.tokenizer = BertTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                vocab_file, do_lower_case=do_lower_case)

    def __call__(
        self,
        tokens: Union[Tuple, List[List[str]]],
        entity_pos: Union[Tuple, List[List[Tuple]]],
        entity_tags: Union[Tuple, List[List[str]]],
    ) -> Tuple[List, List, List, List, List]:
        """
        Tokenize and create masks; recalculate the entity positions with respect to the document borders.
        Args:
            tokens: List of tokens of each document: List[List[tokens in doc]]
            entity_pos: start and end positions of the entities' mentions
            entity_tags: NER tag of the entities
        Return:
            input_ids: List[List[int]],
            attention_mask: List[List[int]],
            entity_poss: List[
                            List[
                                List[(entity1_mention1_start_id, entity1_mention1_end_id), ...],
                                List[(entity2_mention1_start_id, entity2_mention1_end_id), ...]
                            ]
                        ]
            entity_tags: List[List[int]]
            nf_samples: List[int] - indicates whether the corresponding sample is real or fake
                (for testing): 0 means the sample is real, 1 means it is fake.
        """

        _ = self.tokenizer.add_special_tokens(self.special_tokens_dict)

        input_ids, attention_mask, upd_entity_pos, upd_entity_tags, nf_samples = [], [], [], [], []

        # this workaround is needed for proper testing: for an unknown reason, during the tests in
        # test_quick_start.py each input list is transformed into a tuple, e.g., tokens -> tuple(tokens, ).
        # todo: refactoring
        if type(tokens) == tuple and type(entity_pos) == tuple and type(
                entity_tags) == tuple:
            tokens = tokens[0]
            entity_pos = entity_pos[0]
            entity_tags = entity_tags[0]

        for n_sample, (doc, ent_pos, ent_tags) in enumerate(
                zip(tokens, entity_pos, entity_tags)):

            # valid scenario
            if isinstance(ent_pos, list) and len(ent_pos) == 2:
                count = 0
                doc_wordpiece_tokens = []

                entity1_pos_start = list(zip(*ent_pos[0]))[0]  # first entity mentions' start positions
                entity1_pos_end = list(zip(*ent_pos[0]))[1]  # first entity mentions' end positions
                entity2_pos_start = list(zip(*ent_pos[1]))[0]  # second entity mentions' start positions
                entity2_pos_end = list(zip(*ent_pos[1]))[1]  # second entity mentions' end positions

                upd_entity1_pos_start, upd_entity2_pos_start, upd_entity1_pos_end, upd_entity2_pos_end = [], [], [], []
                for n, token in enumerate(doc):
                    if n in entity1_pos_start:
                        doc_wordpiece_tokens.append(self.special_token)
                        upd_entity1_pos_start.append(count)
                        count += 1

                    if n in entity1_pos_end:
                        doc_wordpiece_tokens.append(self.special_token)
                        count += 1
                        upd_entity1_pos_end.append(count)

                    if n in entity2_pos_start:
                        doc_wordpiece_tokens.append(self.special_token)
                        upd_entity2_pos_start.append(count)
                        count += 1

                    if n in entity2_pos_end:
                        doc_wordpiece_tokens.append(self.special_token)
                        count += 1
                        upd_entity2_pos_end.append(count)

                    word_tokens = self.tokenizer.tokenize(token)
                    doc_wordpiece_tokens += word_tokens
                    count += len(word_tokens)

                # special case when the entity is the last in the doc
                if len(doc) in entity1_pos_end:
                    doc_wordpiece_tokens.append(self.special_token)
                    count += 1
                    upd_entity1_pos_end.append(count)
                if len(doc) in entity2_pos_end:
                    doc_wordpiece_tokens.append(self.special_token)
                    count += 1
                    upd_entity2_pos_end.append(count)

                upd_entity_1_pos = list(
                    zip(upd_entity1_pos_start, upd_entity1_pos_end))
                upd_entity_2_pos = list(
                    zip(upd_entity2_pos_start, upd_entity2_pos_end))

                # text entities for self check
                upd_entity1_text = [
                    doc_wordpiece_tokens[ent_m[0]:ent_m[1]]
                    for ent_m in upd_entity_1_pos
                ]
                upd_entity2_text = [
                    doc_wordpiece_tokens[ent_m[0]:ent_m[1]]
                    for ent_m in upd_entity_2_pos
                ]

                enc_entity_tags = self.encode_ner_tag(ent_tags)

                encoding = self.tokenizer.encode_plus(
                    doc_wordpiece_tokens[:self.max_seq_length],  # truncate tokens
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.max_seq_length,
                    pad_to_max_length=True,
                    return_attention_mask=True)
                upd_entity_pos.append([upd_entity_1_pos, upd_entity_2_pos])
                nf_samples.append(0)

            # api test scenario
            else:
                # for the api test: dummy values of entity tags and entity pos
                encoding = self.tokenizer.encode_plus(
                    doc,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.max_seq_length,
                    pad_to_max_length=True,
                    return_attention_mask=True)
                upd_entity_pos.append([[(0, 1)], [(0, 1)]])
                enc_entity_tags = self.encode_ner_tag([self.default_tag] * 2)
                nf_samples.append(1)

            input_ids.append(encoding['input_ids'])
            attention_mask.append(encoding['attention_mask'])
            upd_entity_tags.append(enc_entity_tags)

        return input_ids, attention_mask, upd_entity_pos, upd_entity_tags, nf_samples

    def encode_ner_tag(self, ner_tags: List) -> List:
        """ Encode NER tags with one hot encodings """
        enc_ner_tags = []
        for ner_tag in ner_tags:
            ner_tag_one_hot = [0] * len(self.ner2id)
            ner_tag_one_hot[self.ner2id[ner_tag]] = 1
            enc_ner_tags.append(ner_tag_one_hot)
        return enc_ner_tags
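# Usage sketch for REPreprocessor (the document, entity spans, and tags below are made up;
# 'bert-base-uncased' is used only as an example vocabulary name):
re_prep = REPreprocessor(vocab_file='bert-base-uncased')
docs = [['Barack', 'Obama', 'was', 'born', 'in', 'Hawaii', '.']]
spans = [[[(0, 2)], [(5, 6)]]]  # mention spans of entity 1 and entity 2
tags = [['PER', 'LOC']]
batch_ids, batch_masks, batch_ent_pos, batch_ent_tags, batch_nf = re_prep(docs, spans, tags)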
Example #4
class BuildCustomTransformersVocabulary(object):
    def __init__(self,
                 base_vocab_path='./vocab_small.txt',
                 additional_special_tokens={
                     'additional_special_tokens':
                     ['<num>', '<img>', '<url>', '#E-s', '|||']
                 }):
        self.tokenizer = BertTokenizer(vocab_file=base_vocab_path,
                                       do_lower_case=False,
                                       do_basic_tokenize=True)
        self.tokenizer.add_special_tokens(additional_special_tokens)
        self.no_vocab_tokens = set()

    def get_no_vocab_token(self, text, unk_token='[UNK]', other_split=False):
        """ tokens compare
        @param text:
        @param unk_token:
        @param other_split:  原始拆分出来single token txt, bert tokenizer拆分之后依然拆解为多个token, 是否增加词汇
        @return:
        """
        # text_tokens = self.tokenizer.tokenize(text)  # tokens produced by the bert tokenizer from its vocabulary (may contain unk)
        origin_tokens = self.tokenize(text)  # tokens after splitting; out-of-vocabulary words are not converted to unk

        # # the first approach cannot guarantee a one-to-one mapping: some split characters get split again when converted back
        # assert len(text_tokens) == len(origin_tokens)
        for idx, token in enumerate(origin_tokens):
            # convert with the transformers tokenizer using the base vocabulary
            bert_token = self.tokenizer.tokenize(token)
            # if token != origin_tokens[idx]:
            #     # add the unknown token to the vocabulary
            #     self.no_vocab_tokens.append(origin_tokens[idx])
            if len(bert_token) == 1 and bert_token[0] == unk_token:
                self.no_vocab_tokens.add(token)  # the set removes duplicates
            if other_split and len(bert_token) > 1:
                # a single token is split into several pieces by the bert tokenizer although it should stay whole
                self.no_vocab_tokens.add(token)

    def _tokenize(self, text):
        """将text拆分为 token list"""
        tokens_list = self.tokenizer.basic_tokenizer.tokenize(
            text, never_split=self.tokenizer.all_special_tokens)

        return tokens_list

    def tokenize(self, text: str, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Take care of added tokens.

            Args:
                text (:obj:`string`): The sequence to be encoded.
                **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
        """
        all_special_tokens = self.tokenizer.all_special_tokens
        text = self.tokenizer.prepare_for_tokenization(text, **kwargs)

        # TODO: should this be in the base class?
        def lowercase_text(t):
            # convert non-special tokens to lowercase
            escaped_special_toks = [
                re.escape(s_tok) for s_tok in all_special_tokens
            ]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            return re.sub(pattern,
                          lambda m: m.groups()[0] or m.groups()[1].lower(), t)

        if self.tokenizer.init_kwargs.get("do_lower_case", False):
            text = lowercase_text(text)

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                sub_text = sub_text.rstrip()
                if i == 0 and not sub_text:
                    result += [tok]
                elif i == len(split_text) - 1:
                    if sub_text:
                        result += [sub_text]
                    else:
                        pass
                else:
                    if sub_text:
                        result += [sub_text]
                    result += [tok]
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []
            if not tok_list:
                return self._tokenize(text)

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.tokenizer.unique_added_tokens_encoder:
                        tokenized_text += split_on_token(tok, sub_text)
                    else:
                        tokenized_text += [sub_text]
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (self._tokenize(token)
                     if token not in self.tokenizer.unique_added_tokens_encoder
                     else [token] for token in tokenized_text)))

        added_tokens = self.tokenizer.unique_added_tokens_encoder
        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text

    def update_vocab(self, new_vocab_tokens: list):
        """ 更新原有基础词汇表
        @param new_vocab_tokens:
        @return:
        """
        add_token_num = self.tokenizer.add_tokens(new_vocab_tokens)

        return add_token_num

    def custom_save_vocabulary(self, new_vocab_path):
        """保存新的词汇表"""
        if os.path.exists(new_vocab_path):
            os.remove(new_vocab_path)

        index = 0
        with open(new_vocab_path, mode='w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.tokenizer.vocab.items(),
                                             key=lambda kv: kv[1]):
                if index != token_index:
                    print(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".
                        format(new_vocab_path))
                    index = token_index
                writer.write(token + "\n")
                index += 1

            # append the newly added tokens to the vocabulary
            add_tokens_vocab = OrderedDict(self.tokenizer.added_tokens_encoder)
            for token, token_index in sorted(add_tokens_vocab.items(),
                                             key=lambda kv: kv[1]):
                if index != token_index:
                    print(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".
                        format(new_vocab_path))
                    index = token_index
                writer.write(token + "\n")
                index += 1

        return new_vocab_path

    def save_vocab_pretrained(self, vocab_pretrained_path):
        """保存词表预训练全部内容"""
        if not os.path.exists(vocab_pretrained_path):
            # create the directory if it does not exist
            os.makedirs(vocab_pretrained_path)
        all_file = self.tokenizer.save_pretrained(
            vocab_pretrained_path)  # save all tokenizer files
        # model.resize_token_embeddings(len(tokenizer)) -> resize the embeddings (the vocabulary size has changed)
        # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary,
        # i.e. the length of the tokenizer.
        return all_file
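# Workflow sketch for the class above (paths and texts are placeholders): collect tokens that
# the base vocabulary cannot cover, add them, save the new vocabulary, and then resize the
# model embeddings to len(tokenizer), as the note in save_vocab_pretrained points out.
vocab_builder = BuildCustomTransformersVocabulary(base_vocab_path='./vocab_small.txt')
for _text in ['check the link <url> for details', 'new arrivals <img> in store']:
    vocab_builder.get_no_vocab_token(_text, other_split=False)
vocab_builder.update_vocab(list(vocab_builder.no_vocab_tokens))
vocab_builder.custom_save_vocabulary('./vocab_new.txt')
# model.resize_token_embeddings(len(vocab_builder.tokenizer))  # after the vocabulary changed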
Example #5
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                          bos_token="<BOS>",
                                          eos_token="<EOS>")
tokenizer.tokenize("i like tea")
special_tokens = {"bos_token": "<BOS>", "eos_token": "<EOS>"}
tokenizer.add_special_tokens(special_tokens)

tokenizer.SPECIAL_TOKENS_ATTRIBUTES
tokenizer.bos_token_id
tokenizer.eos_token_id
tokenizer.all_special_ids
tokenizer.special_tokens_map
tokenizer.additional_special_tokens

y = "<BOS> embedding what is the flight number <EOS>"
tokenizer.encode(y)
ids = tokenizer.encode_plus(y)
tokenizer.decode(tokenizer.encode(y))

y = "<BOS> I like embeddings <EOS> [SEP] i like tea"
z = tokenizer.encode(y)
tokenizer.convert_ids_to_tokens(z)
tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(z))

tokenizer.encode("embeddings embedding")
tokenizer.encode("i like tea")

tokenizer.save_pretrained("data/atis/save")
tokenizer.save_vocabulary("data/atis/save/saved")
Example #6
class Trainer:
    def __init__(self, args):
        self.args = args
        self.logger = self.get_logger('Trainer')
        self.random = random.Random(self.args.seed)

        self.device = torch.device(self.args.device_id)
        torch.cuda.set_device(self.device)
        self.logger.info(f'Use device {self.device}')

        self.prepare_stuff()

        self.tokenizer = BertTokenizer(self.args.vocab_path)
        self.tokenizer.add_special_tokens({'bos_token': '[BOS]'})

        self.model, self.optim = self.load_model()
        self.train_batches_all, self.test_batches_all = self.load_batches()

    def load_model(self):
        if self.args.pretrain_path == '' or self.args.resume:
            self.args.pretrain_path = None
        model_args = dict(
            tokenizer=self.tokenizer,
            max_decode_len=self.args.max_decode_len,
            pretrain_path=self.args.pretrain_path,
        )
        if self.args.use_keywords:
            self.logger.info('Creating KWSeq2Seq model...')
            self.model = KWSeq2Seq(**model_args)
        else:
            self.logger.info('Creating Seq2Seq model...')
            self.model = Seq2Seq(**model_args)

        self.logger.info(f'Moving model to {self.device}...')
        self.model = self.model.to(self.device)
        self.logger.info(f'Moving model to {self.device} done.')
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        self.model, self.optim = apex.amp.initialize(self.model,
                                                     self.optim,
                                                     opt_level='O2',
                                                     verbosity=0)

        if self.args.resume:
            self.model, self.optim = self.resume()

        if self.args.n_gpu > 1:
            torch.distributed.init_process_group(
                backend='nccl',
                init_method='tcp://127.0.0.1:29500',
                rank=self.args.rank,
                world_size=self.args.n_gpu,
            )
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.rank],
                output_device=self.args.rank,
                find_unused_parameters=True,
            )
        return self.model, self.optim

    def load_batches(self):
        self.train_batches_all = None
        self.test_batches_all = None

        if self.args.train_pickle_path:
            with open(self.args.train_pickle_path, 'rb') as f:
                batches = pickle.load(f)
            n, m = len(batches), self.args.n_gpu
            r = (n + m - 1) // m * m - n
            self.train_batches_all = batches + batches[:r]

        if self.args.test_pickle_path:
            with open(self.args.test_pickle_path, 'rb') as f:
                self.test_batches_all = pickle.load(f)

        return self.train_batches_all, self.test_batches_all

    def get_logger(self, class_name):
        colors = ['', '\033[92m', '\033[93m', '\033[94m']
        reset_color = '\033[0m'
        self.logger = logging.getLogger(__name__ + '.' + class_name)
        self.logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        s = f'[%(levelname)s] [{class_name}.%(funcName)s] %(message)s'
        if self.args.n_gpu > 1:
            s = f'[Rank {self.args.rank}] ' + s
            s = f' {colors[self.args.rank % 4]}' + s + reset_color
        formatter = logging.Formatter(s)
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        return self.logger

    def prepare_stuff(self):
        # Recoder
        self.recoder = Recoder(tag=self.args.tag,
                               clear=self.args.clear,
                               uri=self.args.mongodb_uri,
                               db=self.args.mongodb_db)

        # Tensorboard
        if self.args.is_worker:
            self.tensorboard_dir = os.path.join(self.args.tensorboard_base_dir,
                                                self.args.tag)
            if self.args.clear and os.path.exists(self.tensorboard_dir):
                shutil.rmtree(self.tensorboard_dir)
                self.logger.info(f'Clear "{self.tensorboard_dir}"')
                time.sleep(1)
            self.writer = SummaryWriter(self.tensorboard_dir)

        # Checkpoint
        self.checkpoints_dir = os.path.join(self.args.checkpoints_base_dir,
                                            self.args.tag)
        if self.args.is_worker:
            if self.args.resume:
                assert os.path.exists(self.checkpoints_dir)
            else:
                if self.args.clear and os.path.exists(self.checkpoints_dir):
                    shutil.rmtree(self.checkpoints_dir)
                    self.logger.info(f'Clear "{self.checkpoints_dir}"')
                os.makedirs(self.checkpoints_dir)

    def resume(self):
        ckpts = {
            int(os.path.splitext(name)[0].split('-')[-1]): name
            for name in os.listdir(self.checkpoints_dir)
        }
        assert len(ckpts) > 0
        self.epoch = max(ckpts)
        path = os.path.join(self.checkpoints_dir, ckpts[self.epoch])

        self.logger.info(f'Resume from "{path}"')
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict['model'], strict=True)
        self.optim.load_state_dict(state_dict['optim'])
        apex.amp.load_state_dict(state_dict['amp'])
        return self.model, self.optim

    def loss_fn(self, input, target):
        loss = torch.nn.functional.cross_entropy(
            input=input.reshape(-1, input.size(-1)),
            target=target[:, 1:].reshape(-1),
            ignore_index=self.tokenizer.pad_token_id,
            reduction='mean',
        )
        return loss

    def train_batch(self, batch):
        # Forward & Loss
        batch.to(self.device)
        if self.args.use_keywords:
            logits, k_logits = self.model(mode='train',
                                          x=batch.x,
                                          y=batch.y,
                                          k=batch.k)
            y_loss = self.loss_fn(input=logits, target=batch.y)
            k_loss = self.loss_fn(input=k_logits, target=batch.k)
            loss = self.args.y_loss_weight * y_loss + self.args.k_loss_weight * k_loss
        else:
            logits = self.model(mode='train', x=batch.x, y=batch.y)
            loss = self.loss_fn(input=logits, target=batch.y)

        # Backward & Optim
        loss = loss / self.args.n_accum_batches
        with apex.amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()
        self.train_steps += 1
        if self.train_steps % self.args.n_accum_batches == 0:
            self.optim.step()
            self.model.zero_grad()

        # Log
        if self.args.is_worker:
            self.writer.add_scalar('_Loss/all', loss.item(), self.train_steps)
            if self.args.use_keywords:
                self.writer.add_scalar('_Loss/y_loss', y_loss.item(),
                                       self.train_steps)
                self.writer.add_scalar('_Loss/k_loss', k_loss.item(),
                                       self.train_steps)

        if self.train_steps % self.args.case_interval == 0:
            y_pred_ids = logits.argmax(dim=-1).tolist()
            y_pred = self.batch_ids_to_strings(y_pred_ids)
            if self.args.use_keywords:
                k_pred_ids = k_logits.argmax(dim=-1).tolist()
                k_pred = self.batch_ids_to_tokens(k_pred_ids)
            else:
                k_pred = None
            self.recoder.record(
                mode='train',
                epoch=self.epoch,
                step=self.train_steps,
                rank=self.args.rank,
                texts_x=batch.texts_x[0],
                text_y=batch.text_y[0],
                y_pred=y_pred[0],
                tokens_k=batch.tokens_k[0] if 'tokens_k' in batch else None,
                k_pred=k_pred[0] if k_pred is not None else None,
            )
        return loss.item()

    def train_epoch(self):
        self.model.train()
        self.model.zero_grad()
        if self.args.is_worker:
            pbar = tqdm(
                self.train_batches,
                desc=f'[{self.args.tag}] [{self.device}] Train {self.epoch}',
                dynamic_ncols=True)
        else:
            pbar = self.train_batches
        for batch in pbar:
            loss = self.train_batch(batch)  # Core training
            if self.args.is_worker:
                pbar.set_postfix({'loss': f'{loss:.4f}'})

        # Checkpoint
        if self.args.is_worker:
            if isinstance(self.model,
                          torch.nn.parallel.DistributedDataParallel):
                model_state_dict = self.model.module.state_dict()
            else:
                model_state_dict = self.model.state_dict()
            state_dict = {
                'model': model_state_dict,
                'optim': self.optim.state_dict(),
                'amp': apex.amp.state_dict(),
            }
            path = os.path.join(self.checkpoints_dir,
                                f'{self.args.tag}-epoch-{self.epoch}.pt')
            torch.save(state_dict, path)

    def fit(self):
        if not self.args.resume:
            self.epoch = 0
            self.train_steps, self.test_steps = 0, 0
        else:
            if self.train_batches_all is not None:
                self.train_steps = self.epoch * len(
                    self.train_batches_all) // self.args.n_gpu
            if self.test_batches_all is not None:
                self.test_steps = self.epoch * len(
                    self.test_batches_all) // self.args.n_gpu

        if self.test_batches_all is not None:
            self.test_batches = self.test_batches_all[self.args.rank::self.args.n_gpu]
            if self.args.is_worker:
                self.recoder.record_target(self.test_batches_all)

        while True:
            self.epoch += 1

            if self.train_batches_all is not None:
                self.random.shuffle(self.train_batches_all)
                self.train_batches = self.train_batches_all[
                    self.args.rank::self.args.n_gpu]
                if self.args.n_gpu > 1:
                    torch.distributed.barrier()
                self.train_epoch()

            if self.test_batches_all is not None:
                if self.args.n_gpu > 1:
                    torch.distributed.barrier()
                results = self.test_epoch()
                self.recoder.record_output(results)

    def test_epoch(self):
        self.model.eval()
        results = []
        if self.args.is_worker:
            pbar = tqdm(
                self.test_batches,
                desc=f'[{self.args.tag}] [{self.device}] Test {self.epoch}',
                dynamic_ncols=True)
        else:
            pbar = self.test_batches
        for batch in pbar:
            batch_results = self.test_batch(batch)  # Core testing
            results += batch_results
            if self.args.is_worker:
                pbar.set_postfix({'step': self.test_steps})
        return results

    def test_batch(self, batch):
        batch.to(self.device)
        results = []
        if self.args.use_keywords:
            y_pred_ids, k_pred_ids = self.model(mode='test', x=batch.x)
            y_pred = self.batch_ids_to_strings(y_pred_ids.tolist())
            k_pred = self.batch_ids_to_tokens(k_pred_ids.tolist())
            for i, y, k in zip(batch.index, y_pred, k_pred):
                results.append({
                    'epoch': self.epoch,
                    'index': i,
                    'rank': self.args.rank,
                    'y': y,
                    'k': k
                })
        else:
            y_pred_ids = self.model(mode='test', x=batch.x)
            y_pred = self.batch_ids_to_strings(y_pred_ids.tolist())
            for i, y in zip(batch.index, y_pred):
                results.append({
                    'epoch': self.epoch,
                    'index': i,
                    'rank': self.args.rank,
                    'y': y
                })
            k_pred = None

        self.test_steps += 1
        if self.test_steps % self.args.case_interval == 0:
            self.recoder.record(
                mode='test',
                epoch=self.epoch,
                step=self.test_steps,
                rank=self.args.rank,
                texts_x=batch.texts_x[0],
                text_y=batch.text_y[0],
                y_pred=y_pred[0],
                tokens_k=batch.tokens_k[0] if 'tokens_k' in batch else None,
                k_pred=k_pred[0] if k_pred is not None else None,
            )
        return results

    def batch_ids_to_strings(self, batch_ids):
        strings = []
        for ids in batch_ids:
            tokens = self.ids_to_tokens(ids)
            string = self.tokenizer.convert_tokens_to_string(tokens)
            strings.append(string)
        return strings

    def batch_ids_to_tokens(self, batch_ids):
        return [self.ids_to_tokens(ids) for ids in batch_ids]

    def ids_to_tokens(self, ids):
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        if tokens[0] == self.tokenizer.bos_token:
            tokens = tokens[1:]
        if tokens.count(self.tokenizer.sep_token) > 0:
            sep_pos = tokens.index(self.tokenizer.sep_token)
            tokens = tokens[:sep_pos]
        return tokens
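# Toy illustration of the target shift in Trainer.loss_fn (shapes and ids are made up): the
# model predicts tokens 1..T from inputs 0..T-1, so the logits have one step fewer than the
# target sequence, and padded positions are skipped via ignore_index.
pad_id = 0
target = torch.tensor([[2, 5, 7, 3, pad_id]])     # e.g. [BOS], w1, w2, [SEP], pad
logits = torch.randn(1, target.size(1) - 1, 100)  # predictions for positions 1..4
toy_loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, 100), target[:, 1:].reshape(-1), ignore_index=pad_id)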
Example #7
    - clinical BERT is used for the WordPiece vocabulary
    """
    with open('data/vocab.pk', 'rb') as fd:
        vocab = pickle.load(fd)

    split_pt = min(vocab.section_start_vocab_id, vocab.category_start_vocab_id)
    metadata_tokens = vocab.i2w[split_pt:] + ['digitparsed']

    with open('data/clinic_bert_vocab.txt', 'r') as fd:
        clinic_bert_tokens = fd.readlines()

    clinic_bert_tokens = list(set(list(map(_standardize, clinic_bert_tokens))))
    vocab_fn = 'data/clinic_bert_plus_metadata_vocab.txt'
    with open(vocab_fn, 'w') as fd:
        fd.write('\n'.join(clinic_bert_tokens + metadata_tokens))
    tokenizer = BertTokenizer(vocab_fn, never_split=metadata_tokens, do_basic_tokenize=False)

    # Add metadata as `additional_special_tokens` so that they do not get subdivided into word pieces
    special_tokens_dict = {'cls_token': '[CLS]', 'sep_token': '[SEP]', 'unk_token': '[UNK]', 'bos_token': '[BOS]',
                           'eos_token': '[EOS]', 'pad_token': '[PAD]', 'mask_token': '[MASK]',
                           'additional_special_tokens': metadata_tokens}

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens')

    example_sentence = 'header=HPI example clinical text with tricky autoimmunihistory word'
    print('Tokenization of sentence: {}'.format(example_sentence))
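    # Presumably the tokenization itself was meant to be printed here; with the metadata tokens
    # registered as additional_special_tokens above, tokenize() keeps them whole instead of
    # splitting them into word pieces. This print is an added illustration, not original code.
    print(tokenizer.tokenize(example_sentence))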
    out_fn = 'data/bert_tokenizer_vocab.pth'
    print('Generated BERT tokenizer vocabulary and saving to {}'.format(out_fn))
    tokenizer.save_vocabulary(out_fn)
Example #8
def loadBertTokenizer(path, special_dict=None):
    tokenizer = BertTokenizer(path)
    if special_dict:
        tokenizer.add_special_tokens(special_dict)
    return tokenizer
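# Usage sketch (the vocabulary path and the extra tokens are placeholders):
tokenizer = loadBertTokenizer('vocab.txt',
                              special_dict={'additional_special_tokens': ['<url>', '<img>']})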