コード例 #1
0
 def print_statistics(self):
     """
     Log how many sequences the corpus holds.

     Only the master process (`self.params.is_master`) logs; every other
     process returns without doing anything.
     """
     if self.params.is_master:
         logger.info(f'{len(self)} sequences')
コード例 #2
0
    def remove_long_sequences(self):
        """
        Split sequences that are too long into chunks of at most
        `max_position_embeddings` tokens.

        A sequence is left untouched when its length is `<= max_position_embeddings`.
        Longer sequences are cut into consecutive slices of
        `max_position_embeddings - 2` tokens; each slice gets the `cls_token` id
        prepended if it does not already start with it, and the `sep_token` id
        appended if it does not already end with it.

        `self.token_ids` and `self.lengths` are replaced by new numpy arrays
        built from the kept and the newly created sequences, in order.

        Raises:
        -------
            AssertionError: if a chunk is longer than `max_position_embeddings`
                after the special tokens were added.
        """
        max_len = self.params.max_position_embeddings

        # Count with the same strict comparison the loop below uses: a sequence
        # of exactly `max_len` tokens is kept as is, so it must not be reported
        # as "too long".
        n_too_long = int(np.sum(self.lengths > max_len))
        logger.info(f'Splitting {n_too_long} too long sequences.')

        cls_id = self.params.special_tok_ids['cls_token']
        sep_id = self.params.special_tok_ids['sep_token']
        # Leave room for one leading and one trailing special token.
        chunk_size = max_len - 2

        new_tok_ids = []
        new_lengths = []
        for seq_, len_ in zip(self.token_ids, self.lengths):
            if len_ <= max_len:
                new_tok_ids.append(seq_)
                new_lengths.append(len_)
                continue

            for start in range(0, len(seq_), chunk_size):
                sub_s = seq_[start:start + chunk_size]
                if sub_s[0] != cls_id:
                    sub_s = np.insert(sub_s, 0, cls_id)
                if sub_s[-1] != sep_id:
                    sub_s = np.insert(sub_s, len(sub_s), sep_id)
                assert len(sub_s) <= max_len
                new_tok_ids.append(sub_s)
                new_lengths.append(len(sub_s))

        self.token_ids = np.array(new_tok_ids)
        self.lengths = np.array(new_lengths)
コード例 #3
0
 def get_iterator(self):
     """
     Create the progress-bar wrapped data iterator.

     Seeds through `set_seed(self.params)` and then stores a `tqdm` wrapper
     around `self.dataloader` in `self.data_iterator`.
     """
     logger.info('--- Initializing Data Iterator')
     set_seed(self.params)

     # The bar is only shown when local_rank is -1 or 0.
     hide_bar = self.params.local_rank not in [-1, 0]

     # NOTE(review): the total is `max_steps` modulo the epoch length, which
     # is 0 when `max_steps` is an exact multiple of `num_steps_epoch` --
     # confirm this is the intended progress-bar length.
     if self.params.max_steps > 0:
         n_total = self.params.max_steps % self.num_steps_epoch
     else:
         n_total = None

     self.data_iterator = tqdm(self.dataloader,
                               desc="Iteration",
                               disable=hide_bar,
                               total=n_total)
コード例 #4
0
 def split(self):
     """
     Distributed training: keep only this process' share of the data.

     The dataset is cut into `world_size` contiguous shards of equal size
     (floor division, so trailing sequences that do not fill a shard are not
     assigned to any process) and the shard at index `global_rank` is selected.
     """
     assert self.params.n_gpu > 1
     logger.info('Splitting the data accross the processuses.')
     shard_size = len(self) // self.params.world_size
     first = shard_size * self.params.global_rank
     self.select_data(a=first, b=first + shard_size)
コード例 #5
0
    def get_iterator(self, seed: int = None):
        """
        Create the data iterator and store it in `self.data_iterator`.

        Input:
        ------
            seed: `int` - The random seed, forwarded to the dataloader's
                `get_iterator`.
        """
        logger.info('--- Initializing Data Iterator')
        iterator = self.dataloader.get_iterator(seed=seed)
        self.data_iterator = iterator
コード例 #6
0
 def remove_empty_sequences(self):
     """
     Drop sequences that are too short (11 tokens or fewer).

     `self.token_ids` and `self.lengths` are filtered with the same boolean
     mask, and the number of removed sequences is logged.
     """
     size_before = len(self)

     keep_mask = self.lengths > 11
     self.token_ids = self.token_ids[keep_mask]
     self.lengths = self.lengths[keep_mask]

     n_removed = size_before - len(self)
     logger.info(
         f'Remove {n_removed} too short (<=11 tokens) sequences.'
     )
コード例 #7
0
    def select_data(self, a: int, b: int):
        """
        Restrict the dataset to the sequences in the half-open range [a, b).

        Input:
        ------
            a: `int` - Index of the first sequence to keep.
            b: `int` - Index one past the last sequence to keep.

        The bounds must satisfy `0 <= a < b <= len(self)`; this is checked with
        an `assert`. `self.check()` is called once the data has been sliced.
        """
        total = len(self)
        bounds_ok = 0 <= a < b <= total
        assert bounds_ok, ValueError(
            f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')

        logger.info(f'Selecting sequences from {a} to {b} (excluded).')
        window = slice(a, b)
        self.token_ids = self.token_ids[window]
        self.lengths = self.lengths[window]

        self.check()
コード例 #8
0
    def end_epoch(self):
        """
        Close the current epoch (one full pass on the dataset).

        Saves a checkpoint named after the epoch, logs the mean epoch loss to
        tensorboard, then increments the epoch number and resets the per-epoch
        counters.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        checkpoint_file = f'model_epoch_{self.epoch}.pth'
        self.save_checkpoint(checkpoint_name=checkpoint_file)

        writer = self.tensorboard
        mean_epoch_loss = self.total_loss_epoch / self.n_iter
        writer.add_scalar(tag='epoch/loss',
                          scalar_value=mean_epoch_loss,
                          global_step=self.epoch)

        # Move on to the next epoch with fresh per-epoch accumulators.
        self.epoch += 1
        self.n_sequences_epoch = self.n_iter = self.total_loss_epoch = 0
コード例 #9
0
    def train(self):
        """
        Run the training loop for `params.n_epoch` epochs.

        Each epoch performs `self.num_steps_epoch` steps: fetch a batch, move it
        to `cuda:{local_rank}` when `n_gpu > 0`, prepare it and call `self.step`.
        `self.end_epoch()` is called after every epoch. When all epochs are done,
        the master process saves a final checkpoint as `pytorch_model.bin`.
        """
        if self.is_master:
            logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _epoch in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            # The progress bar is advanced by hand; it is hidden unless
            # local_rank is -1 or 0.
            hide_bar = self.params.local_rank not in [-1, 0]
            progress = trange(self.num_steps_epoch,
                              desc="-Iter",
                              disable=hide_bar)
            for _step in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    device = f'cuda:{self.params.local_rank}'
                    batch = tuple(tensor.to(device) for tensor in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(
                    batch=batch)

                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          mlm_labels=mlm_labels)

                progress.update()
                last_loss_txt = f'{self.last_loss:.2f}'
                avg_loss_txt = f'{self.total_loss_epoch/self.n_iter:.2f}'
                progress.set_postfix({
                    'Last_loss': last_loss_txt,
                    'Avg_cum_loss': avg_loss_txt
                })
            progress.close()

            if self.is_master:
                logger.info(
                    f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if not self.is_master:
            return
        logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
        self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
        logger.info('Training is finished')
コード例 #10
0
    def __init__(self, params: dict, dataloader: Dataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        """
        Set up the distillation run: data, loss weights, optimizer, scheduler,
        optional fp16 / distributed wrapping and tensorboard logging.

        Input:
        ------
            params: Run configuration. Despite the `dict` annotation it is read
                through attribute access (`params.n_gpu`, `params.seed`, ...).
            dataloader: The training data. `split()` is called on it when
                `params.n_gpu > 1`, and its length sizes an epoch.
            token_probs: Tensor moved to `cuda:{local_rank}` when
                `params.n_gpu > 0` and cast to half precision when fp16 is on.
            student: The model whose parameters are optimized.
            teacher: The reference model; cast to half precision when fp16 is on.

        Raises:
        -------
            AssertionError: if the temperature, the loss weights, the masking
                proportion, the masking probabilities or
                `gradient_accumulation_steps` are out of range.
            ImportError: if fp16 is requested but `apex` is not installed.
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        # Loss weights: each must be non-negative and at least one must be on.
        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_cos >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        # Compare with a tolerance: an exact `== 1.0` rejects valid splits such
        # as 0.7/0.2/0.1, whose floating-point sum is 0.9999999999999999.
        assert math.isclose(
            params.word_mask + params.word_keep + params.word_rand, 1.0)
        self.pred_probs = torch.FloatTensor(
            [params.word_mask, params.word_keep, params.word_rand])
        if params.n_gpu > 0:
            device = f'cuda:{params.local_rank}'
            self.pred_probs = self.pred_probs.to(device)
            self.token_probs = token_probs.to(device)
        else:
            self.token_probs = token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        # Training-progress counters and last observed loss values.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        if self.alpha_mse > 0.:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.:
            self.last_loss_cos = 0
        self.last_log = 0

        # The MSE and cosine losses only exist when their weight is positive.
        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(
            len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(
            self.num_steps_epoch / params.gradient_accumulation_steps *
            params.n_epoch) + 1

        # Parameters whose name contains one of these substrings get no weight
        # decay; frozen parameters are left out of the optimizer entirely.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in student.named_parameters()
                if not any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            params.weight_decay
        }, {
            'params': [
                p for n, p in student.named_parameters()
                if any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            0.0
        }]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum(
                p.numel() for p in self.student.parameters() if p.requires_grad
            ))
        logger.info("------ Number of parameters (student): %i" %
                    sum(p.numel() for p in self.student.parameters()))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=warmup_steps,
            t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            # apex ships its own DistributedDataParallel, used under fp16.
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank)

        # Only the master process owns a tensorboard writer.
        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config',
                                      text_string=str(self.params),
                                      global_step=0)
コード例 #11
0
    def __init__(self,
                 params: dict,
                 dataloader: Dataset,
                 student: nn.Module,
                 teacher: nn.Module,
                 tokenizer: nn.Module):
        """
        Set up the distillation run: data iterator, loss weights, optimizer,
        learning-rate schedule and tensorboard logging.

        This constructor does not split the dataloader across processes and
        does not wrap the student in `DistributedDataParallel`.

        Input:
        ------
            params: Run configuration, read through attribute access
                (`params.output_dir`, `params.temperature`, ...).
            dataloader: The training data; its length is the number of steps
                per epoch.
            student: The model whose parameters are optimized.
            teacher: The reference model.
            tokenizer: Stored on the instance as `self.tokenizer`.
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.output_dir = params.output_dir

        self.student = student
        self.teacher = teacher
        self.tokenizer = tokenizer

        self.dataloader = dataloader
        # Must be set before get_iterator() is called.
        self.num_steps_epoch = len(self.dataloader)
        self.get_iterator()

        self.temperature = params.temperature
        assert self.temperature > 0.

        # Loss weights: each must be non-negative and at least one must be on.
        self.alpha_ce = params.alpha_ce
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        assert self.alpha_ce >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_cos >= 0.
        assert self.alpha_ce + self.alpha_mse + self.alpha_cos > 0.

        # Training-progress counters and last observed loss values.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        if self.alpha_mse > 0.:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.:
            self.last_loss_cos = 0
        self.last_log = 0

        # The MSE and cosine losses only exist when their weight is positive.
        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        steps_per_epoch = self.num_steps_epoch / params.gradient_accumulation_steps
        num_train_optimization_steps = int(steps_per_epoch * params.n_epoch) + 1

        # Sort trainable parameters into "decayed" and "not decayed" groups;
        # a parameter is not decayed when its name contains one of these
        # substrings. Frozen parameters are left out of the optimizer.
        no_decay = ['bias', 'LayerNorm.weight']
        decayed = []
        undecayed = []
        for name, param in student.named_parameters():
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)
        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': params.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0}
        ]

        n_trainable = sum([p.numel() for p in self.student.parameters() if p.requires_grad])
        logger.info("------ Number of trainable parameters (student): %i" % n_trainable)
        n_params = sum([p.numel() for p in self.student.parameters()])
        logger.info("------ Number of parameters (student): %i" % n_params)

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        logger.info('--- Initializing Tensorboard')
        # One log directory per run: timestamp + host name.
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        run_name = current_time + '_' + socket.gethostname()
        logdir = os.path.join(self.params.output_dir, "tensorboard", run_name)
        self.tensorboard = SummaryWriter(log_dir=logdir, flush_secs=60)
        self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
コード例 #12
0
    def train(self):
        """
        Run the training loop for at most `params.n_epoch` epochs.

        For every batch yielded by `self.data_iterator`, the tensors are moved
        to `params.device` and `self.step` is called. When
        `params.evaluate_during_training` is set, the student is evaluated
        every `params.log_interval` total iterations on the process whose
        `local_rank` is -1 or 0, and every metric except `conf_mtrx` is written
        to tensorboard.

        Training stops early once the `max_steps` check below triggers. After
        each completed epoch a fresh `tqdm` iterator is built for the next one.
        A final checkpoint is saved and the tensorboard writer is closed at the
        end.
        """
        logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()
        do_stop = False
        for epoch_number in range(self.params.n_epoch):
            logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            for step, batch in enumerate(self.data_iterator):
                # Single transfer to the configured device. The previous extra
                # move to f'cuda:{local_rank}' built the invalid device
                # 'cuda:-1' when local_rank == -1 (non-distributed GPU run),
                # and was redundant otherwise since the batch ended up on
                # `params.device` anyway.
                batch = tuple(t.to(self.params.device) for t in batch)
                self.step(batch)

                should_evaluate = (
                    self.n_total_iter % self.params.log_interval == 0
                    and self.params.local_rank in [-1, 0]
                    and self.params.evaluate_during_training
                )
                if should_evaluate:
                    results = evaluate(self.params, self.student, self.tokenizer, prefix="e{}s{}".format(epoch_number, step))
                    for key, value in results.items():
                        if key == "conf_mtrx":
                            continue
                        self.tensorboard.add_scalar('eval_{}'.format(key), value, global_step=self.n_total_iter)

                # NOTE(review): `n_total_iter` is not updated in this method;
                # if it already counts the steps of the current epoch, adding
                # `step` counts them twice -- confirm against `self.step`.
                if self.params.max_steps > 0 and self.n_total_iter + step > self.params.max_steps:
                    self.data_iterator.close()
                    do_stop = True
                    break

            if do_stop:
                logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1} early due to max_steps')
                self.end_epoch()
                break

            # Rebuild the progress-bar iterator for the next epoch.
            self.data_iterator.close()
            self.data_iterator = tqdm(self.dataloader, desc="Iteration", disable=self.params.local_rank not in [-1, 0])

            logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        logger.info('Save very last checkpoint as `pytorch_model.bin`.')
        self.save_checkpoint()
        self.tensorboard.close()
        logger.info('Training is finished')
コード例 #13
0
from collections import Counter
import argparse
import pickle

from examples.distillation.utils import logger

if __name__ == '__main__':
    # Script entry point: load the binarized dataset, count how often each
    # token id occurs, and dump the counts as a list indexed by token id.
    parser = argparse.ArgumentParser(
        description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)")
    parser.add_argument("--data_file",
                        type=str,
                        default="data/dump.bert-base-uncased.pickle",
                        help="The binarized dataset.")
    parser.add_argument("--token_counts_dump",
                        type=str,
                        default="data/token_counts.bert-base-uncased.pickle",
                        help="The dump file.")
    parser.add_argument("--vocab_size", default=30522, type=int)
    args = parser.parse_args()

    logger.info(f'Loading data from {args.data_file}')
    with open(args.data_file, 'rb') as fp:
        data = pickle.load(fp)

    logger.info('Counting occurences for MLM.')
    counter = Counter()
    for sequence in data:
        counter.update(sequence)

    # Dense list of counts; ids that never occur keep a count of 0.
    counts = [0 for _ in range(args.vocab_size)]
    for token_id, n_occurrences in counter.items():
        counts[token_id] = n_occurrences

    logger.info(f'Dump to {args.token_counts_dump}')
    with open(args.token_counts_dump, 'wb') as handle:
        pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
コード例 #14
0
def main():
    """
    Tokenize a text file once and dump the resulting token ids.

    Every line of `--file_path` is stripped, wrapped as `[CLS] ... [SEP]` and
    encoded with the tokenizer named by `--bert_tokenizer`. The encoded
    sequences are converted with `np.uint16`, shuffled, and pickled to
    `<dump_file>.<bert_tokenizer>.pickle`.
    """
    parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
    parser.add_argument('--file_path', type=str, default='data/dump.txt',
                        help='The path to the data.')
    parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
                        help="The tokenizer to use.")
    parser.add_argument('--dump_file', type=str, default='data/dump',
                        help='The dump file prefix.')
    args = parser.parse_args()

    logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

    logger.info(f'Loading text from {args.file_path}')
    with open(args.file_path, 'r', encoding='utf8') as fp:
        data = fp.readlines()

    logger.info('Start encoding')
    logger.info(f'{len(data)} examples to process.')

    rslt = []
    interval = 10000
    start = time.time()
    # `enumerate` replaces the manual counter that shadowed the builtin `iter`.
    for n_done, text in enumerate(data, start=1):
        text = f'[CLS] {text.strip()} [SEP]'
        token_ids = bert_tokenizer.encode(text)
        rslt.append(token_ids)

        # Report the average time per example over the last `interval` lines.
        if n_done % interval == 0:
            end = time.time()
            logger.info(f'{n_done} examples processed. - {(end-start)/interval:.2f}s/expl')
            start = time.time()
    logger.info('Finished binarization')
    logger.info(f'{len(data)} examples processed.')

    dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
    # NOTE(review): uint16 can only hold ids up to 65535 -- confirm the
    # tokenizer's vocabulary fits before using a larger vocabulary.
    rslt_ = [np.uint16(d) for d in rslt]
    random.shuffle(rslt_)
    logger.info(f'Dump to {dp_file}')
    with open(dp_file, 'wb') as handle:
        pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)