Example #1
# Likely imports for this snippet: PyKomoran, KoNLPy's Mecab, HuggingFace transformers,
# and pkg_resources for the bundled vocabulary file.
from pkg_resources import resource_filename

from PyKomoran import Komoran
from konlpy.tag import Mecab
from transformers import BertTokenizer


class Tokenizer:
    def __init__(self, tokenizer_name="komoran"):
        tokenizer_name = tokenizer_name.lower()
        assert tokenizer_name in ("komoran", "mecab", "subword"), \
            "Only 'komoran', 'mecab', and 'subword' are acceptable."
        if tokenizer_name == "komoran":
            self.tokenizer = Komoran("STABLE")
        elif tokenizer_name == "mecab":
            self.tokenizer = Mecab()
        elif tokenizer_name == "subword":
            self.tokenizer = BertTokenizer(resource_filename(__package__, "vocab_noised.txt"), do_lower_case=False)
        self.tokenizer_name = tokenizer_name

    def tokenize(self, text):
        if self.tokenizer_name == "komoran":
            return self.tokenizer.get_morphes_by_tags(text)
        elif self.tokenizer_name == "mecab":
            return self.tokenizer.morphs(text)
        else:  # self.tokenizer_name is "subword"
            return self.tokenizer.tokenize(text)

    def post_process(self, tokens):
        if self.tokenizer_name == "komoran":
            return " ".join(tokens)
        elif self.tokenizer_name == "mecab":
            return " ".join(tokens)
        else:  # self.tokenizer_name is "subword" or moduletype
            return self.tokenizer.convert_tokens_to_string(tokens)
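
For reference, a minimal usage sketch of the class above (it assumes the analyzers named in the imports are installed; the Korean sentence is illustrative only):

tok = Tokenizer("mecab")
tokens = tok.tokenize("자연어 처리는 재미있다")   # morpheme list from Mecab
print(tok.post_process(tokens))                  # morphemes re-joined with spaces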
Example #2
class BertBPE(object):
    def __init__(self, cfg):
        try:
            from transformers import BertTokenizer
        except ImportError:
            raise ImportError(
                "Please install transformers with: pip install transformers")

        if cfg.bpe_vocab_file:
            self.bert_tokenizer = BertTokenizer(
                cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased)
        else:
            vocab_file_name = ("bert-base-cased"
                               if cfg.bpe_cased else "bert-base-uncased")
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                vocab_file_name)

    def encode(self, x: str) -> str:
        return " ".join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.bert_tokenizer.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(" ")))

    def is_beginning_of_word(self, x: str) -> bool:
        return not x.startswith("##")
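
A short usage sketch of this wrapper (cfg is a stand-in object with the two fields the constructor reads; downloading "bert-base-uncased" needs network access):

from types import SimpleNamespace

cfg = SimpleNamespace(bpe_vocab_file=None, bpe_cased=False)  # hypothetical config
bpe = BertBPE(cfg)

encoded = bpe.encode("Tokenizers split words into subwords.")
print(encoded)                              # space-joined WordPiece tokens, e.g. "token ##izer ##s ..."
print(bpe.decode(encoded))                  # cleaned-up, detokenized text
print(bpe.is_beginning_of_word("##izer"))   # False: a continuation piece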
Example #3
class NemoBertTokenizer(TokenizerSpec):
    def __init__(
            self,
            pretrained_model=None,
            vocab_file=None,
            do_lower_case=True,
            max_len=None,
            do_basic_tokenize=True,
            never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"),
    ):
        if pretrained_model:
            self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)
            if "uncased" not in pretrained_model:
                self.tokenizer.basic_tokenizer.do_lower_case = False
        else:
            self.tokenizer = BertTokenizer(vocab_file, do_lower_case,
                                           do_basic_tokenize)
        self.vocab_size = len(self.tokenizer.vocab)
        self.never_split = never_split

    def text_to_tokens(self, text):
        tokens = self.tokenizer.tokenize(text)
        return tokens

    def tokens_to_text(self, tokens):
        text = self.tokenizer.convert_tokens_to_string(tokens)
        return remove_spaces(handle_quotes(text.strip()))

    def token_to_id(self, token):
        return self.tokens_to_ids([token])[0]

    def tokens_to_ids(self, tokens):
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return ids

    def ids_to_tokens(self, ids):
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        return tokens

    def text_to_ids(self, text):
        tokens = self.text_to_tokens(text)
        ids = self.tokens_to_ids(tokens)
        return ids

    def ids_to_text(self, ids):
        tokens = self.ids_to_tokens(ids)
        tokens_clean = [t for t in tokens if t not in self.never_split]
        text = self.tokens_to_text(tokens_clean)
        return text

    def pad_id(self):
        return self.tokens_to_ids(["[PAD]"])[0]

    def bos_id(self):
        return self.tokens_to_ids(["[CLS]"])[0]

    def eos_id(self):
        return self.tokens_to_ids(["[SEP]"])[0]
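
A minimal sketch of the round trip this class exposes (it assumes the surrounding NeMo definitions, TokenizerSpec and the remove_spaces/handle_quotes helpers, are importable and that the pretrained vocabulary can be downloaded):

tok = NemoBertTokenizer(pretrained_model="bert-base-uncased")

ids = tok.text_to_ids("subword tokenization")
print(tok.ids_to_tokens(ids))                     # WordPiece tokens, e.g. ['sub', '##word', ...]
print(tok.pad_id(), tok.bos_id(), tok.eos_id())   # ids of [PAD], [CLS], [SEP]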
Example #4
class BertBPE(object):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-cased',
                            action='store_true',
                            help='set for cased BPE',
                            default=False)
        parser.add_argument('--bpe-vocab-file',
                            type=str,
                            help='bpe vocab file.')
        # fmt: on

    def __init__(self, args):
        try:
            from transformers import BertTokenizer
            from pytorch_transformers.tokenization_utils import clean_up_tokenization
        except ImportError:
            raise ImportError(
                'Please install the 1.0.0 version of pytorch_transformers '
                'with: pip install pytorch-transformers')

        if getattr(args, 'bpe_vocab_file', None):
            # Only use the vocab-file branch when a path was actually given.
            self.bert_tokenizer = BertTokenizer(
                args.bpe_vocab_file, do_lower_case=not args.bpe_cased)
        else:
            vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                vocab_file_name)
        # Needed in both branches, since decode() always calls it.
        self.clean_up_tokenization = clean_up_tokenization

    def encode(self, x: str) -> str:
        return ' '.join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(' ')))

    def is_beginning_of_word(self, x: str) -> bool:
        return not x.startswith('##')
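
A short sketch of how the parser hook and the encoder fit together (flag names taken from add_args above; with no --bpe-vocab-file the constructor falls back to the pretrained bert-base-cased vocabulary):

import argparse

parser = argparse.ArgumentParser()
BertBPE.add_args(parser)
args = parser.parse_args(['--bpe-cased'])   # no --bpe-vocab-file given
bpe = BertBPE(args)
print(bpe.encode("WordPiece splits unknown words."))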
Example #5
from transformers import BertTokenizer

# The opening call is cut off in the source; "bert-base-uncased" is assumed here,
# matching the BertTokenizerFast call at the end of this example.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                          bos_token="<BOS>",
                                          eos_token="<EOS>")
tokenizer.tokenize("i like tea")
special_tokens = {"bos_token": "<BOS>", "eos_token": "<EOS>"}
tokenizer.add_special_tokens(special_tokens)

tokenizer.bos_token_id
tokenizer.eos_token_id
tokenizer.all_special_ids

tokenizer.special_tokens_map
tokenizer.additional_special_tokens
y = "<BOS> I like embeddings <EOS> [SEP] i like tea"
z = tokenizer.encode(y)
tokenizer.convert_ids_to_tokens(z)
tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(z))

tokenizer.encode("embeddings embedding")
tokenizer.encode("i like tea")
tokenizer.encode("i like tea")
tokenizer.decode(tokenizer.encode("embeddings embedding"))

tokenizer.get_special_tokens_mask([100, 101, 102], [1, 2, 3])
tokenizer.get_special_tokens_mask([100, 101, 102, 1, 2, 3])

tokenizer("s")
from transformers import BertTokenizerFast
t1 = BertTokenizerFast.from_pretrained("bert-base-uncased",
                                       bos_token="<BOS>",
                                       eos_token="<EOS>")
Example #6
class Trainer:
    def __init__(self, args):
        self.args = args
        self.logger = self.get_logger('Trainer')
        self.random = random.Random(self.args.seed)

        self.device = torch.device(self.args.device_id)
        torch.cuda.set_device(self.device)
        self.logger.info(f'Use device {self.device}')

        self.prepare_stuff()

        self.tokenizer = BertTokenizer(self.args.vocab_path)
        self.tokenizer.add_special_tokens({'bos_token': '[BOS]'})

        self.model, self.optim = self.load_model()
        self.train_batches_all, self.test_batches_all = self.load_batches()

    def load_model(self):
        if self.args.pretrain_path == '' or self.args.resume:
            self.args.pretrain_path = None
        model_args = dict(
            tokenizer=self.tokenizer,
            max_decode_len=self.args.max_decode_len,
            pretrain_path=self.args.pretrain_path,
        )
        if self.args.use_keywords:
            self.logger.info('Creating KWSeq2Seq model...')
            self.model = KWSeq2Seq(**model_args)
        else:
            self.logger.info('Creating Seq2Seq model...')
            self.model = Seq2Seq(**model_args)

        self.logger.info(f'Moving model to {self.device}...')
        self.model = self.model.to(self.device)
        self.logger.info(f'Moving model to {self.device} done.')
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        self.model, self.optim = apex.amp.initialize(self.model,
                                                     self.optim,
                                                     opt_level='O2',
                                                     verbosity=0)

        if self.args.resume:
            self.model, self.optim = self.resume()

        if self.args.n_gpu > 1:
            torch.distributed.init_process_group(
                backend='nccl',
                init_method='tcp://127.0.0.1:29500',
                rank=self.args.rank,
                world_size=self.args.n_gpu,
            )
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.args.rank],
                output_device=self.args.rank,
                find_unused_parameters=True,
            )
        return self.model, self.optim

    def load_batches(self):
        self.train_batches_all = None
        self.test_batches_all = None

        if self.args.train_pickle_path:
            with open(self.args.train_pickle_path, 'rb') as f:
                batches = pickle.load(f)
            n, m = len(batches), self.args.n_gpu
            r = (n + m - 1) // m * m - n
            self.train_batches_all = batches + batches[:r]

        if self.args.test_pickle_path:
            with open(self.args.test_pickle_path, 'rb') as f:
                self.test_batches_all = pickle.load(f)

        return self.train_batches_all, self.test_batches_all

    def get_logger(self, class_name):
        colors = ['', '\033[92m', '\033[93m', '\033[94m']
        reset_color = '\033[0m'
        self.logger = logging.getLogger(__name__ + '.' + class_name)
        self.logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        s = f'[%(levelname)s] [{class_name}.%(funcName)s] %(message)s'
        if self.args.n_gpu > 1:
            s = f'[Rank {self.args.rank}] ' + s
            s = f' {colors[self.args.rank % 4]}' + s + reset_color
        formatter = logging.Formatter(s)
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        return self.logger

    def prepare_stuff(self):
        # Recoder
        self.recoder = Recoder(tag=self.args.tag,
                               clear=self.args.clear,
                               uri=self.args.mongodb_uri,
                               db=self.args.mongodb_db)

        # Tensorboard
        if self.args.is_worker:
            self.tensorboard_dir = os.path.join(self.args.tensorboard_base_dir,
                                                self.args.tag)
            if self.args.clear and os.path.exists(self.tensorboard_dir):
                shutil.rmtree(self.tensorboard_dir)
                self.logger.info(f'Clear "{self.tensorboard_dir}"')
                time.sleep(1)
            self.writer = SummaryWriter(self.tensorboard_dir)

        # Checkpoint
        self.checkpoints_dir = os.path.join(self.args.checkpoints_base_dir,
                                            self.args.tag)
        if self.args.is_worker:
            if self.args.resume:
                assert os.path.exists(self.checkpoints_dir)
            else:
                if self.args.clear and os.path.exists(self.checkpoints_dir):
                    shutil.rmtree(self.checkpoints_dir)
                    self.logger.info(f'Clear "{self.checkpoints_dir}"')
                os.makedirs(self.checkpoints_dir)

    def resume(self):
        ckpts = {
            int(os.path.splitext(name)[0].split('-')[-1]): name
            for name in os.listdir(self.checkpoints_dir)
        }
        assert len(ckpts) > 0
        self.epoch = max(ckpts)
        path = os.path.join(self.checkpoints_dir, ckpts[self.epoch])

        self.logger.info(f'Resume from "{path}"')
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict['model'], strict=True)
        self.optim.load_state_dict(state_dict['optim'])
        apex.amp.load_state_dict(state_dict['amp'])
        return self.model, self.optim

    def loss_fn(self, input, target):
        loss = torch.nn.functional.cross_entropy(
            input=input.reshape(-1, input.size(-1)),
            target=target[:, 1:].reshape(-1),
            ignore_index=self.tokenizer.pad_token_id,
            reduction='mean',
        )
        return loss

    def train_batch(self, batch):
        # Forward & Loss
        batch.to(self.device)
        if self.args.use_keywords:
            logits, k_logits = self.model(mode='train',
                                          x=batch.x,
                                          y=batch.y,
                                          k=batch.k)
            y_loss = self.loss_fn(input=logits, target=batch.y)
            k_loss = self.loss_fn(input=k_logits, target=batch.k)
            loss = self.args.y_loss_weight * y_loss + self.args.k_loss_weight * k_loss
        else:
            logits = self.model(mode='train', x=batch.x, y=batch.y)
            loss = self.loss_fn(input=logits, target=batch.y)

        # Backward & Optim
        loss = loss / self.args.n_accum_batches
        with apex.amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()
        self.train_steps += 1
        if self.train_steps % self.args.n_accum_batches == 0:
            self.optim.step()
            self.model.zero_grad()

        # Log
        if self.args.is_worker:
            self.writer.add_scalar('_Loss/all', loss.item(), self.train_steps)
            if self.args.use_keywords:
                self.writer.add_scalar('_Loss/y_loss', y_loss.item(),
                                       self.train_steps)
                self.writer.add_scalar('_Loss/k_loss', k_loss.item(),
                                       self.train_steps)

        if self.train_steps % self.args.case_interval == 0:
            y_pred_ids = logits.argmax(dim=-1).tolist()
            y_pred = self.batch_ids_to_strings(y_pred_ids)
            if self.args.use_keywords:
                k_pred_ids = k_logits.argmax(dim=-1).tolist()
                k_pred = self.batch_ids_to_tokens(k_pred_ids)
            else:
                k_pred = None
            self.recoder.record(
                mode='train',
                epoch=self.epoch,
                step=self.train_steps,
                rank=self.args.rank,
                texts_x=batch.texts_x[0],
                text_y=batch.text_y[0],
                y_pred=y_pred[0],
                tokens_k=batch.tokens_k[0] if 'tokens_k' in batch else None,
                k_pred=k_pred[0] if k_pred is not None else None,
            )
        return loss.item()

    def train_epoch(self):
        self.model.train()
        self.model.zero_grad()
        if self.args.is_worker:
            pbar = tqdm(
                self.train_batches,
                desc=f'[{self.args.tag}] [{self.device}] Train {self.epoch}',
                dynamic_ncols=True)
        else:
            pbar = self.train_batches
        for batch in pbar:
            loss = self.train_batch(batch)  # Core training
            if self.args.is_worker:
                pbar.set_postfix({'loss': f'{loss:.4f}'})

        # Checkpoint
        if self.args.is_worker:
            if isinstance(self.model,
                          torch.nn.parallel.DistributedDataParallel):
                model_state_dict = self.model.module.state_dict()
            else:
                model_state_dict = self.model.state_dict()
            state_dict = {
                'model': model_state_dict,
                'optim': self.optim.state_dict(),
                'amp': apex.amp.state_dict(),
            }
            path = os.path.join(self.checkpoints_dir,
                                f'{self.args.tag}-epoch-{self.epoch}.pt')
            torch.save(state_dict, path)

    def fit(self):
        if not self.args.resume:
            self.epoch = 0
            self.train_steps, self.test_steps = 0, 0
        else:
            if self.train_batches_all is not None:
                self.train_steps = self.epoch * len(
                    self.train_batches_all) // self.args.n_gpu
            if self.test_batches_all is not None:
                self.test_steps = self.epoch * len(
                    self.test_batches_all) // self.args.n_gpu

        if self.test_batches_all is not None:
            self.test_batches = self.test_batches_all[self.args.rank::self.args.n_gpu]
            if self.args.is_worker:
                self.recoder.record_target(self.test_batches_all)

        while True:
            self.epoch += 1

            if self.train_batches_all is not None:
                self.random.shuffle(self.train_batches_all)
                self.train_batches = self.train_batches_all[
                    self.args.rank::self.args.n_gpu]
                if self.args.n_gpu > 1:
                    torch.distributed.barrier()
                self.train_epoch()

            if self.test_batches_all is not None:
                if self.args.n_gpu > 1:
                    torch.distributed.barrier()
                results = self.test_epoch()
                self.recoder.record_output(results)

    def test_epoch(self):
        self.model.eval()
        results = []
        if self.args.is_worker:
            pbar = tqdm(
                self.test_batches,
                desc=f'[{self.args.tag}] [{self.device}] Test {self.epoch}',
                dynamic_ncols=True)
        else:
            pbar = self.test_batches
        for batch in pbar:
            batch_results = self.test_batch(batch)  # Core testing
            results += batch_results
            if self.args.is_worker:
                pbar.set_postfix({'step': self.test_steps})
        return results

    def test_batch(self, batch):
        batch.to(self.device)
        results = []
        if self.args.use_keywords:
            y_pred_ids, k_pred_ids = self.model(mode='test', x=batch.x)
            y_pred = self.batch_ids_to_strings(y_pred_ids.tolist())
            k_pred = self.batch_ids_to_tokens(k_pred_ids.tolist())
            for i, y, k in zip(batch.index, y_pred, k_pred):
                results.append({
                    'epoch': self.epoch,
                    'index': i,
                    'rank': self.args.rank,
                    'y': y,
                    'k': k
                })
        else:
            y_pred_ids = self.model(mode='test', x=batch.x)
            y_pred = self.batch_ids_to_strings(y_pred_ids.tolist())
            for i, y in zip(batch.index, y_pred):
                results.append({
                    'epoch': self.epoch,
                    'index': i,
                    'rank': self.args.rank,
                    'y': y
                })
            k_pred = None

        self.test_steps += 1
        if self.test_steps % self.args.case_interval == 0:
            self.recoder.record(
                mode='test',
                epoch=self.epoch,
                step=self.test_steps,
                rank=self.args.rank,
                texts_x=batch.texts_x[0],
                text_y=batch.text_y[0],
                y_pred=y_pred[0],
                tokens_k=batch.tokens_k[0] if 'tokens_k' in batch else None,
                k_pred=k_pred[0] if k_pred is not None else None,
            )
        return results

    def batch_ids_to_strings(self, batch_ids):
        strings = []
        for ids in batch_ids:
            tokens = self.ids_to_tokens(ids)
            string = self.tokenizer.convert_tokens_to_string(tokens)
            strings.append(string)
        return strings

    def batch_ids_to_tokens(self, batch_ids):
        return [self.ids_to_tokens(ids) for ids in batch_ids]

    def ids_to_tokens(self, ids):
        tokens = self.tokenizer.convert_ids_to_tokens(ids)
        if tokens[0] == self.tokenizer.bos_token:
            tokens = tokens[1:]
        if tokens.count(self.tokenizer.sep_token) > 0:
            sep_pos = tokens.index(self.tokenizer.sep_token)
            tokens = tokens[:sep_pos]
        return tokens
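
A standalone sketch of the trimming convention in ids_to_tokens above (uncased BERT vocabulary assumed; the token sequence is illustrative only):

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-uncased')
tok.add_special_tokens({'bos_token': '[BOS]'})

ids = tok.convert_tokens_to_ids(['[BOS]', 'hello', 'world', '[SEP]', '[PAD]'])
tokens = tok.convert_ids_to_tokens(ids)
if tokens and tokens[0] == tok.bos_token:        # drop the leading [BOS]
    tokens = tokens[1:]
if tok.sep_token in tokens:                      # cut everything from the first [SEP]
    tokens = tokens[:tokens.index(tok.sep_token)]
print(tokens)                                    # ['hello', 'world']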
Example #7
class DataHelper(object):
    def __init__(self):
        self.tokenizer = BertTokenizer(config.RSC_DIR +
                                       '/char_vocab/vocab.txt',
                                       do_lower_case=False,
                                       do_basic_tokenize=False)
        self.data_load()

    def data_load(self):
        train_file = config.RSC_DIR + '/sbd/news/news_train.txt'
        valid_file = config.RSC_DIR + '/sbd/news/news_valid.txt'
        test_file = config.RSC_DIR + '/sbd/news/news_test.txt'
        target_vocab = config.RSC_DIR + '/sbd/news/target_vocab.txt'
        self.vocab_load(target_vocab)
        self.train_data = self.file_load(train_file)
        self.valid_data = self.file_load(valid_file)
        self.test_data = self.file_load(test_file)

    def vocab_load(self, file_path):
        self.target2idx = {}
        self.idx2target = {}
        with open(file_path, encoding='utf-8') as f:
            for token in f:
                token = token.strip()
                idx = len(self.target2idx)
                self.target2idx[token] = idx
                self.idx2target[idx] = token

    def text2tokens(self, text):
        tokens = []
        for sen in text.split(' / '):
            tokens += self.tokenizer.tokenize(sen)
        return tokens

    def text2targets(self, text):
        targets = []
        for sen in text.split(' / '):
            targets += self.text2target(self.tokenizer.tokenize(sen))
        return targets

    def tokens2text(self, minibatch_input, minibatch_pred):
        result = []
        for input, pred in zip(minibatch_input, minibatch_pred):
            ids = []
            batch_result = []
            for i, p in zip(input, pred):
                if i == self.tokenizer.vocab['[CLS]']:
                    continue
                elif p == self.target2idx['O']:
                    ids.append(i)
                    tokens = self.tokenizer.convert_ids_to_tokens(ids)
                    batch_result.append(
                        self.tokenizer.convert_tokens_to_string(tokens))
                    ids = []
                else:
                    ids.append(i)
            result.append(batch_result)
        return result

    def text2target(self, tokens):
        target = []
        for i, token in enumerate(tokens):
            if i == 0:
                target.append('B')
            elif i == len(tokens) - 1:
                target.append('O')
            else:
                target.append('I')
        return target

    def file_load(self, file_path):
        input = data.Field(sequential=True,
                           use_vocab=True,
                           tokenize=self.text2tokens,
                           init_token='[CLS]',
                           pad_token='[PAD]',
                           unk_token='[UNK]',
                           lower=False,
                           batch_first=True)
        target = data.Field(sequential=True,
                            use_vocab=True,
                            tokenize=self.text2targets,
                            init_token='[CLS]',
                            pad_token='[PAD]',
                            unk_token='[UNK]',
                            lower=False,
                            batch_first=True)

        dataset = data.TabularDataset(path=file_path,
                                      format='tsv',
                                      fields=[('input', input),
                                              ('target', target)],
                                      skip_header=False)

        input.build_vocab(dataset)
        target.build_vocab(dataset)
        input.vocab.stoi = self.tokenizer.vocab
        input.vocab.itos = self.tokenizer.ids_to_tokens
        target.vocab.stoi = self.target2idx
        target.vocab.itos = self.idx2target

        return dataset
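
For clarity, a standalone illustration of the labelling scheme implemented by text2target above (sentence-boundary detection: the first token of a sentence is 'B', the last is 'O', and everything in between is 'I'):

def text2target(tokens):
    return ['B' if i == 0 else 'O' if i == len(tokens) - 1 else 'I'
            for i in range(len(tokens))]

print(text2target(['이', '문장', '은', '예', '##시', '입니다', '.']))
# ['B', 'I', 'I', 'I', 'I', 'I', 'O']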