def print_statistics(self):
    """
    Print some statistics on the corpus. Only the master process logs them.
    """
    if not self.params.is_master:
        return
    logger.info(f'{len(self)} sequences')
def remove_long_sequences(self):
    """
    Sequences that are too long are split into chunks of max_position_embeddings.
    """
    indices = self.lengths >= self.params.max_position_embeddings
    logger.info(f'Splitting {sum(indices)} too long sequences.')

    def divide_chunks(l, n):
        return [l[i:i + n] for i in range(0, len(l), n)]

    new_tok_ids = []
    new_lengths = []
    cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
    max_len = self.params.max_position_embeddings

    for seq_, len_ in zip(self.token_ids, self.lengths):
        if len_ <= max_len:
            new_tok_ids.append(seq_)
            new_lengths.append(len_)
        else:
            sub_seqs = []
            for sub_s in divide_chunks(seq_, max_len - 2):
                if sub_s[0] != cls_id:
                    sub_s = np.insert(sub_s, 0, cls_id)
                if sub_s[-1] != sep_id:
                    sub_s = np.insert(sub_s, len(sub_s), sep_id)
                assert len(sub_s) <= max_len
                sub_seqs.append(sub_s)
            new_tok_ids.extend(sub_seqs)
            new_lengths.extend([len(l) for l in sub_seqs])

    self.token_ids = np.array(new_tok_ids)
    self.lengths = np.array(new_lengths)
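# A minimal sketch of the chunking behaviour above, with illustrative values
# (cls_id=101, sep_id=102 and max_len=512 correspond to bert-base-uncased):
# a 1000-token sequence is cut into chunks of at most max_len - 2 tokens, and each
# chunk is re-wrapped with [CLS] ... [SEP] so it fits the model's position range.
import numpy as np

cls_id, sep_id, max_len = 101, 102, 512
seq = np.array([cls_id] + list(range(1000, 1998)) + [sep_id], dtype=np.uint16)  # 1000 tokens

chunks = [seq[i:i + (max_len - 2)] for i in range(0, len(seq), max_len - 2)]
wrapped = []
for sub_s in chunks:
    if sub_s[0] != cls_id:
        sub_s = np.insert(sub_s, 0, cls_id)
    if sub_s[-1] != sep_id:
        sub_s = np.insert(sub_s, len(sub_s), sep_id)
    wrapped.append(sub_s)

assert all(len(sub_s) <= max_len for sub_s in wrapped)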
def get_iterator(self):
    """
    Initialize the data iterator.
    Each process has its own data iterator (iterating on its own random portion of the dataset).
    """
    logger.info('--- Initializing Data Iterator')
    set_seed(self.params)
    self.data_iterator = tqdm(self.dataloader,
                              desc="Iteration",
                              disable=self.params.local_rank not in [-1, 0],
                              total=(self.params.max_steps % self.num_steps_epoch
                                     if self.params.max_steps > 0 else None))
def split(self):
    """
    Distributed training: split the data across the processes.
    """
    assert self.params.n_gpu > 1
    logger.info('Splitting the data across the processes.')
    n_seq = len(self)
    n_seq_per_process = n_seq // self.params.world_size
    a = n_seq_per_process * self.params.global_rank
    b = a + n_seq_per_process
    self.select_data(a=a, b=b)
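# A worked example of the split above (numbers are illustrative): with 1000 sequences,
# world_size=4 and global_rank=2, each process keeps 1000 // 4 = 250 sequences and
# this process selects the slice [500, 750).
n_seq, world_size, global_rank = 1000, 4, 2
n_seq_per_process = n_seq // world_size   # 250
a = n_seq_per_process * global_rank       # 500
b = a + n_seq_per_process                 # 750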
def get_iterator(self, seed: int = None):
    """
    Initialize the data iterator.
    Each process has its own data iterator (iterating on its own random portion of the dataset).

    Input:
    ------
    seed: `int` - The random seed.
    """
    logger.info('--- Initializing Data Iterator')
    self.data_iterator = self.dataloader.get_iterator(seed=seed)
def remove_empty_sequences(self):
    """
    Too short sequences are simply removed. This could be tuned.
    """
    init_size = len(self)
    indices = self.lengths > 11
    self.token_ids = self.token_ids[indices]
    self.lengths = self.lengths[indices]
    new_size = len(self)
    logger.info(f'Removed {init_size - new_size} too short (<=11 tokens) sequences.')
def select_data(self, a: int, b: int):
    """
    Select a sub-portion of the data.
    """
    n_sequences = len(self)
    assert 0 <= a < b <= n_sequences, \
        f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}'
    logger.info(f'Selecting sequences from {a} to {b} (excluded).')
    self.token_ids = self.token_ids[a:b]
    self.lengths = self.lengths[a:b]
    self.check()
def end_epoch(self):
    """
    Finally arrived at the end of epoch (full pass on dataset).
    Do some tensorboard logging and checkpoint saving.
    """
    logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')
    self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
    self.tensorboard.add_scalar(tag='epoch/loss',
                                scalar_value=self.total_loss_epoch / self.n_iter,
                                global_step=self.epoch)

    self.epoch += 1
    self.n_sequences_epoch = 0
    self.n_iter = 0
    self.total_loss_epoch = 0
def train(self):
    """
    The real training loop.
    """
    if self.is_master:
        logger.info('Starting training')
    self.last_log = time.time()
    self.student.train()
    self.teacher.eval()

    for _ in range(self.params.n_epoch):
        if self.is_master:
            logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch - 1}')
        if self.multi_gpu:
            torch.distributed.barrier()

        iter_bar = trange(self.num_steps_epoch,
                          desc="-Iter",
                          disable=self.params.local_rank not in [-1, 0])
        for __ in range(self.num_steps_epoch):
            batch = self.get_batch()
            if self.params.n_gpu > 0:
                batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
            token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
            self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

            iter_bar.update()
            iter_bar.set_postfix({
                'Last_loss': f'{self.last_loss:.2f}',
                'Avg_cum_loss': f'{self.total_loss_epoch / self.n_iter:.2f}'
            })
        iter_bar.close()

        if self.is_master:
            logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch - 1}')
        self.end_epoch()

    if self.is_master:
        logger.info('Save very last checkpoint as `pytorch_model.bin`.')
        self.save_checkpoint(checkpoint_name='pytorch_model.bin')
        logger.info('Training is finished')
def __init__(self,
             params: dict,
             dataloader: Dataset,
             token_probs: torch.tensor,
             student: nn.Module,
             teacher: nn.Module):
    logger.info('Initializing Distiller')
    self.params = params
    self.dump_path = params.dump_path
    self.multi_gpu = params.multi_gpu
    self.fp16 = params.fp16

    self.student = student
    self.teacher = teacher

    self.dataloader = dataloader
    if self.params.n_gpu > 1:
        self.dataloader.split()
    self.get_iterator(seed=params.seed)

    self.temperature = params.temperature
    assert self.temperature > 0.

    self.alpha_ce = params.alpha_ce
    self.alpha_mlm = params.alpha_mlm
    self.alpha_mse = params.alpha_mse
    self.alpha_cos = params.alpha_cos
    assert self.alpha_ce >= 0.
    assert self.alpha_mlm >= 0.
    assert self.alpha_mse >= 0.
    assert self.alpha_cos >= 0.
    assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

    self.mlm_mask_prop = params.mlm_mask_prop
    assert 0.0 <= self.mlm_mask_prop <= 1.0
    assert params.word_mask + params.word_keep + params.word_rand == 1.0
    self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
    self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs
    self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
    if self.fp16:
        self.pred_probs = self.pred_probs.half()
        self.token_probs = self.token_probs.half()

    self.epoch = 0
    self.n_iter = 0
    self.n_total_iter = 0
    self.n_sequences_epoch = 0
    self.total_loss_epoch = 0
    self.last_loss = 0
    self.last_loss_ce = 0
    self.last_loss_mlm = 0
    if self.alpha_mse > 0.:
        self.last_loss_mse = 0
    if self.alpha_cos > 0.:
        self.last_loss_cos = 0
    self.last_log = 0

    self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
    self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
    if self.alpha_mse > 0.:
        self.mse_loss_fct = nn.MSELoss(reduction='sum')
    if self.alpha_cos > 0.:
        self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

    logger.info('--- Initializing model optimizer')
    assert params.gradient_accumulation_steps >= 1
    self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
    num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': params.weight_decay},
        {'params': [p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': 0.0}
    ]
    logger.info("------ Number of trainable parameters (student): %i"
                % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
    logger.info("------ Number of parameters (student): %i"
                % sum([p.numel() for p in self.student.parameters()]))
    self.optimizer = AdamW(optimizer_grouped_parameters,
                           lr=params.learning_rate,
                           eps=params.adam_epsilon,
                           betas=(0.9, 0.98))

    warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
    self.scheduler = WarmupLinearSchedule(self.optimizer,
                                          warmup_steps=warmup_steps,
                                          t_total=num_train_optimization_steps)

    if self.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
        self.student, self.optimizer = amp.initialize(self.student,
                                                      self.optimizer,
                                                      opt_level=self.params.fp16_opt_level)
        self.teacher = self.teacher.half()

    if self.multi_gpu:
        if self.fp16:
            from apex.parallel import DistributedDataParallel
            logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
            self.student = DistributedDataParallel(self.student)
        else:
            from torch.nn.parallel import DistributedDataParallel
            logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
            self.student = DistributedDataParallel(self.student,
                                                   device_ids=[params.local_rank],
                                                   output_device=params.local_rank)

    self.is_master = params.is_master
    if self.is_master:
        logger.info('--- Initializing Tensorboard')
        self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
        self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
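# A small arithmetic check of the optimizer/scheduler horizon computed above
# (all numbers here are illustrative, not values from the original setup):
import math

num_sequences = 8_000_000
batch_size = 32
gradient_accumulation_steps = 4
n_epoch = 3
warmup_prop = 0.05

num_steps_epoch = int(num_sequences / batch_size) + 1                    # 250001 batches per epoch
num_train_optimization_steps = int(num_steps_epoch / gradient_accumulation_steps * n_epoch) + 1  # 187501 optimizer updates
warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)     # 9376 warmup steps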
def __init__(self,
             params: dict,
             dataloader: Dataset,
             student: nn.Module,
             teacher: nn.Module,
             tokenizer: nn.Module):
    logger.info('Initializing Distiller')
    self.params = params
    self.output_dir = params.output_dir
    # Multi-GPU / distributed training is disabled in this variant.
    # self.multi_gpu = params.multi_gpu

    self.student = student
    self.teacher = teacher
    self.tokenizer = tokenizer

    self.dataloader = dataloader
    # if self.params.n_gpu > 1:
    #     self.dataloader.split()
    # self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
    self.num_steps_epoch = len(self.dataloader)
    self.get_iterator()

    self.temperature = params.temperature
    assert self.temperature > 0.

    self.alpha_ce = params.alpha_ce
    self.alpha_mse = params.alpha_mse
    self.alpha_cos = params.alpha_cos
    assert self.alpha_ce >= 0.
    assert self.alpha_mse >= 0.
    assert self.alpha_cos >= 0.
    assert self.alpha_ce + self.alpha_mse + self.alpha_cos > 0.

    self.epoch = 0
    self.n_iter = 0
    self.n_total_iter = 0
    self.n_sequences_epoch = 0
    self.total_loss_epoch = 0
    self.last_loss = 0
    self.last_loss_ce = 0
    if self.alpha_mse > 0.:
        self.last_loss_mse = 0
    if self.alpha_cos > 0.:
        self.last_loss_cos = 0
    self.last_log = 0

    self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
    if self.alpha_mse > 0.:
        self.mse_loss_fct = nn.MSELoss(reduction='sum')
    if self.alpha_cos > 0.:
        self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

    logger.info('--- Initializing model optimizer')
    assert params.gradient_accumulation_steps >= 1
    num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': params.weight_decay},
        {'params': [p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad],
         'weight_decay': 0.0}
    ]
    logger.info("------ Number of trainable parameters (student): %i"
                % sum([p.numel() for p in self.student.parameters() if p.requires_grad]))
    logger.info("------ Number of parameters (student): %i"
                % sum([p.numel() for p in self.student.parameters()]))
    self.optimizer = AdamW(optimizer_grouped_parameters,
                           lr=params.learning_rate,
                           eps=params.adam_epsilon,
                           betas=(0.9, 0.98))

    warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
    self.scheduler = WarmupLinearSchedule(self.optimizer,
                                          warmup_steps=warmup_steps,
                                          t_total=num_train_optimization_steps)

    # Distributed (multi-GPU) wrapping of the student with DistributedDataParallel
    # is disabled in this variant (see the fp16/multi-GPU Distiller above).

    logger.info('--- Initializing Tensorboard')
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    logdir = os.path.join(self.params.output_dir, "tensorboard",
                          current_time + '_' + socket.gethostname())
    self.tensorboard = SummaryWriter(log_dir=logdir, flush_secs=60)
    self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)
def train(self):
    """
    The real training loop.
    """
    logger.info('Starting training')
    self.last_log = time.time()
    self.student.train()
    self.teacher.eval()

    do_stop = False
    for epoch_number in range(self.params.n_epoch):
        logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch - 1}')
        # if self.multi_gpu:
        #     torch.distributed.barrier()

        for step, batch in enumerate(self.data_iterator):
            if self.params.n_gpu > 0:
                # Move the batch to the training device.
                batch = tuple(t.to(self.params.device) for t in batch)
            self.step(batch)

            if (self.n_total_iter % self.params.log_interval == 0
                    and self.params.local_rank in [-1, 0]
                    and self.params.evaluate_during_training):
                results = evaluate(self.params, self.student, self.tokenizer,
                                   prefix="e{}s{}".format(epoch_number, step))
                for key, value in results.items():
                    if key == "conf_mtrx":
                        continue
                    self.tensorboard.add_scalar('eval_{}'.format(key), value,
                                                global_step=self.n_total_iter)

            if self.params.max_steps > 0 and self.n_total_iter + step > self.params.max_steps:
                self.data_iterator.close()
                do_stop = True
                break

        if do_stop:
            logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch - 1} early due to max_steps')
            self.end_epoch()
            break

        self.data_iterator.close()
        self.data_iterator = tqdm(self.dataloader,
                                  desc="Iteration",
                                  disable=self.params.local_rank not in [-1, 0])
        logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch - 1}')
        self.end_epoch()

    logger.info('Save very last checkpoint as `pytorch_model.bin`.')
    self.save_checkpoint()
    self.tensorboard.close()
    logger.info('Training is finished')
from collections import Counter
import argparse
import pickle

from examples.distillation.utils import logger

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Token counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)")
    parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle",
                        help="The binarized dataset.")
    parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle",
                        help="The dump file.")
    parser.add_argument("--vocab_size", default=30522, type=int)
    args = parser.parse_args()

    logger.info(f'Loading data from {args.data_file}')
    with open(args.data_file, 'rb') as fp:
        data = pickle.load(fp)

    logger.info('Counting occurrences for MLM.')
    counter = Counter()
    for tk_ids in data:
        counter.update(tk_ids)
    counts = [0] * args.vocab_size
    for k, v in counter.items():
        counts[k] = v

    logger.info(f'Dump to {args.token_counts_dump}')
    with open(args.token_counts_dump, 'wb') as handle:
        pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
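# The dumped counts are meant to be turned into smoothed masking probabilities for MLM,
# as in XLM: rarer tokens get masked more often. A minimal sketch, assuming a smoothing
# exponent of 0.7 and the counts file produced above (the exact loading code in the
# training script may differ, and special-token ids are typically zeroed out there):
import pickle

import numpy as np
import torch

with open("data/token_counts.bert-base-uncased.pickle", "rb") as fp:
    counts = pickle.load(fp)

mlm_smoothing = 0.7  # assumed hyper-parameter
token_probs = np.maximum(counts, 1) ** -mlm_smoothing  # frequency^-0.7: rarer tokens -> higher probability
token_probs = torch.from_numpy(token_probs)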
def main():
    parser = argparse.ArgumentParser(
        description="Preprocess the data (tokenization + token_to_ids) once to avoid re-doing it several times.")
    parser.add_argument('--file_path', type=str, default='data/dump.txt',
                        help='The path to the data.')
    parser.add_argument('--bert_tokenizer', type=str, default='bert-base-uncased',
                        help="The tokenizer to use.")
    parser.add_argument('--dump_file', type=str, default='data/dump',
                        help='The dump file prefix.')
    args = parser.parse_args()

    logger.info(f'Loading Tokenizer ({args.bert_tokenizer})')
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

    logger.info(f'Loading text from {args.file_path}')
    with open(args.file_path, 'r', encoding='utf8') as fp:
        data = fp.readlines()

    logger.info('Start encoding')
    logger.info(f'{len(data)} examples to process.')

    rslt = []
    iter_ = 0
    interval = 10000
    start = time.time()
    for text in data:
        text = f'[CLS] {text.strip()} [SEP]'
        token_ids = bert_tokenizer.encode(text)
        rslt.append(token_ids)

        iter_ += 1
        if iter_ % interval == 0:
            end = time.time()
            logger.info(f'{iter_} examples processed. - {(end - start) / interval:.2f}s/expl')
            start = time.time()
    logger.info('Finished binarization')
    logger.info(f'{len(data)} examples processed.')

    dp_file = f'{args.dump_file}.{args.bert_tokenizer}.pickle'
    rslt_ = [np.uint16(d) for d in rslt]
    random.shuffle(rslt_)
    logger.info(f'Dump to {dp_file}')
    with open(dp_file, 'wb') as handle:
        pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
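# Example invocation, assuming the script above is saved as `binarized_data.py` and is
# called through the usual `if __name__ == '__main__': main()` guard (both assumptions):
#
#   python binarized_data.py --file_path data/dump.txt \
#                            --bert_tokenizer bert-base-uncased \
#                            --dump_file data/dump
#
# This writes `data/dump.bert-base-uncased.pickle`, which matches the default
# `--data_file` expected by the token-counts script above.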