def train_model(model: torch.nn.Module,
                train_dl: torch.utils.data.DataLoader,
                dev_dl: torch.utils.data.DataLoader,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int) -> Tuple[Dict, Dict]:
    loss_f = torch.nn.CrossEntropyLoss()
    best_val, best_model_weights = {'val_f1': 0}, None

    for ep in range(n_epochs):
        model.train()
        for batch in tqdm(train_dl, desc='Training'):
            optimizer.zero_grad()
            logits = model(batch[0])
            loss = loss_f(logits, batch[1])
            loss.backward()
            optimizer.step()

        # Evaluate once per epoch on the dev set.
        val_p, val_r, val_f1, val_loss, _, _ = eval_model(model, dev_dl)
        current_val = {
            'val_p': val_p, 'val_r': val_r, 'val_f1': val_f1,
            'val_loss': val_loss, 'ep': ep
        }
        print(current_val, flush=True)

        # Keep the weights of the best-F1 epoch.
        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            best_model_weights = model.state_dict()

        # Step the LR scheduler on the validation loss (ReduceLROnPlateau-style call).
        scheduler.step(val_loss)

    return best_model_weights, best_val
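# --- Usage sketch (illustrative, not part of the original code) ---
# Minimal wiring for `train_model` above, assuming `eval_model(model, dev_dl)`
# from the same codebase returns (precision, recall, f1, loss, preds, labels).
# The toy model, random data and the ReduceLROnPlateau choice are assumptions;
# the annotation above says LambdaLR, but `scheduler.step(val_loss)` matches
# the ReduceLROnPlateau interface.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_model = torch.nn.Linear(16, 3)                      # placeholder classifier
X, y = torch.randn(64, 16), torch.randint(0, 3, (64,))  # random toy data
train_dl = DataLoader(TensorDataset(X, y), batch_size=8)
dev_dl = DataLoader(TensorDataset(X, y), batch_size=8)

optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)

best_weights, best_metrics = train_model(toy_model, train_dl, dev_dl,
                                         optimizer, scheduler, n_epochs=3)
toy_model.load_state_dict(best_weights)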
def restore_state(optimizer, scheduler: torch.optim.lr_scheduler.LambdaLR, args, device):
    latest_checkpoint_file = get_latest_checkpoint_path(args.output)
    if latest_checkpoint_file or args.warm_start:
        if args.local_rank in [None, 0]:
            print('Trying to load optimizer and scheduler state from checkpoint.')
        checkpoint = torch.load(latest_checkpoint_file or args.pre_model,
                                map_location=device)
        iteration = checkpoint.get('iteration', 0)
        if checkpoint.get('scheduler_state_dict', None):
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if args.local_rank in [None, 0]:
                print("Scheduler restored.")
        if checkpoint.get('optimizer_state_dict', None):
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if args.local_rank in [None, 0]:
                print("Optimizer restored.")
        loss = checkpoint.get('loss', 0)
        return iteration, loss
    else:
        return 0, 0.0
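# --- Companion sketch (assumption, not from the original code) ---
# `restore_state` above expects a checkpoint dict with the keys 'iteration',
# 'loss', 'optimizer_state_dict' and 'scheduler_state_dict'; a matching save
# helper could look like this. The path handling (`get_latest_checkpoint_path`,
# `args.output`, `args.pre_model`) is taken as given from the original context.
def save_state(path, iteration, loss, optimizer, scheduler):
    torch.save({
        'iteration': iteration,
        'loss': loss,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, path)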
def train_one_epoch(model, train_loader: DataLoader, optimizer: AdamW,
                    scheduler: torch.optim.lr_scheduler.LambdaLR,
                    epoch: int, max_epoch: int):
    """Loops over the dataloader for one epoch and updates the model.

    Args:
        model (BertForSequenceClassification): model being fine-tuned
        train_loader (DataLoader): batches of (input ids, attention masks, labels)
        optimizer (AdamW): optimizer stepped after every batch
        scheduler (torch.optim.lr_scheduler.LambdaLR): LR scheduler stepped after every batch
        epoch (int): current epoch number (for the progress bar)
        max_epoch (int): total number of epochs (for the progress bar)

    Returns:
        tuple: (model, optimizer, scheduler, average training loss over the epoch)
    """
    total_loss = 0
    model.train()
    train_bar = tqdm(train_loader)
    train_bar.desc = f"Epoch {epoch}/{max_epoch} Loss: 0"
    for batch in train_bar:
        # `batch` contains pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        # The model returns (loss, logits, ...) when labels are supplied;
        # equivalently: torch.nn.functional.cross_entropy(logits, b_labels)
        loss, logits = outputs[:2]

        total_loss += loss.item()
        loss.backward()
        utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        train_bar.desc = f"Epoch {epoch}/{max_epoch} Batch Loss: {loss.item()}"

    # Note: cross-entropy loss is averaged per sample, so it does not grow with batch size.
    avg_train_loss = total_loss / len(train_loader)
    return model, optimizer, scheduler, avg_train_loss
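# --- Usage sketch (illustrative, placeholder names and hyperparameters) ---
# Typical wiring for `train_one_epoch` with the Hugging Face classes named in
# its docstring. The dummy data, model name and step counts are assumptions,
# and `device` is the module-level global the function reads.
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      num_labels=2).to(device)

input_ids = torch.randint(0, 30522, (32, 64))  # dummy token ids
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 2, (32,))
train_loader = DataLoader(TensorDataset(input_ids, attention_mask, labels),
                          batch_size=8)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 3
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0,
    num_training_steps=len(train_loader) * num_epochs)

for epoch in range(1, num_epochs + 1):
    model, optimizer, scheduler, avg_loss = train_one_epoch(
        model, train_loader, optimizer, scheduler, epoch, num_epochs)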
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler,
                dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int,
                early_stopping: EarlyStopping) -> Tuple[Dict, Dict]:
    # Binary classification uses BCE-with-logits; otherwise multi-class cross-entropy.
    if args.labels == 2:
        loss_f = torch.nn.BCEWithLogitsLoss()
    else:
        loss_f = torch.nn.CrossEntropyLoss()

    best_val, best_model_weights = {'val_f1': 0}, None
    for ep in range(n_epochs):
        losses = []
        model.train()
        for i, batch in enumerate(tqdm(train_dl, desc='Training')):
            optimizer.zero_grad()
            prediction = model(batch[0])
            loss = loss_f(prediction, batch[1])
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print('Training loss:', np.mean(losses))

        val_p, val_r, val_f1, val_loss = eval_model(model, dev_dl)
        current_val = {
            'val_p': val_p, 'val_r': val_r, 'val_f1': val_f1,
            'val_loss': val_loss, 'ep': ep
        }
        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            best_model_weights = model.state_dict()

        scheduler.step(val_loss)
        if early_stopping.step(val_loss):
            print('Early stopping...')
            break

    return best_model_weights, best_val
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler,
                dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int,
                labels: int = 3,
                early_stopping: EarlyStopping = None) -> Tuple[Dict, Dict]:
    best_val, best_model_weights = {'val_f1': 0}, None
    for ep in range(n_epochs):
        for batch in tqdm(train_dl, desc='Training'):
            model.train()
            optimizer.zero_grad()
            # The model returns (loss, logits, ...) when labels are supplied.
            loss, _ = model(batch[0],
                            attention_mask=batch[1],
                            labels=batch[2].long())[:2]
            loss.backward()
            optimizer.step()
            scheduler.step()

        val_p, val_r, val_f1, val_loss, _, _ = eval_model(model, dev_dl, labels)
        current_val = {
            'val_f1': val_f1, 'val_p': val_p, 'val_r': val_r,
            'val_loss': val_loss, 'ep': ep
        }
        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            best_model_weights = model.state_dict()

        if early_stopping and early_stopping.step(val_f1):
            print('Early stopping...')
            break

    return best_model_weights, best_val
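# --- Helper sketch (assumption, not defined in these snippets) ---
# The two `train_model` variants above only require `EarlyStopping` to expose a
# `.step(metric)` method that returns True once training should stop. A minimal
# patience-based version for a metric that should decrease (e.g. validation
# loss) could look like this; for an increasing metric such as val_f1, the
# comparison would need to be flipped.
class EarlyStopping:
    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best = float('inf')
        self.bad_steps = 0

    def step(self, metric: float) -> bool:
        # True once the metric has failed to improve `patience` times in a row.
        if metric < self.best:
            self.best = metric
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience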
def train(classifier_model: BertForMultipleChoice,
          optimizer: torch.optim.Optimizer,
          scheduler: torch.optim.lr_scheduler.LambdaLR,
          dataloader: torch.utils.data.DataLoader,
          device: torch.device,
          logger: logging.Logger,
          logging_steps: int) -> float:
    """
    Trains the model for one pass over the dataloader.
    """
    classifier_model.train()

    count_steps = 0
    total_loss = 0

    for batch_idx, batch in enumerate(dataloader):
        batch = tuple(t for t in batch)
        bert_inputs = {
            "input_ids": batch[0].to(device),
            "attention_mask": batch[1].to(device),
            "token_type_ids": batch[2].to(device),
            "labels": batch[3].to(device)
        }

        classifier_model.zero_grad()

        # Forward propagation
        classifier_outputs = classifier_model(**bert_inputs)
        loss = classifier_outputs[0]

        # Backward propagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), 1.0)

        # Update parameters
        optimizer.step()
        scheduler.step()

        count_steps += 1
        total_loss += loss.item()

        if batch_idx % logging_steps == 0:
            logger.info("Step: {}\tLoss: {}\tlr: {}".format(
                batch_idx, loss.item(), optimizer.param_groups[0]["lr"]))

    return total_loss / count_steps
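# --- Usage sketch (illustrative, placeholder names and data) ---
# Minimal wiring for the multiple-choice `train` loop above. The dummy tensors
# follow BertForMultipleChoice's expected (batch, num_choices, seq_len) input
# shape; the model name, choice count and learning rate are assumptions.
import logging
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForMultipleChoice, get_linear_schedule_with_warmup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier_model = BertForMultipleChoice.from_pretrained('bert-base-uncased').to(device)

num_choices, seq_len = 4, 64
input_ids = torch.randint(0, 30522, (16, num_choices, seq_len))
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)
labels = torch.randint(0, num_choices, (16,))
dataloader = DataLoader(
    TensorDataset(input_ids, attention_mask, token_type_ids, labels),
    batch_size=4)

optimizer = torch.optim.AdamW(classifier_model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                            num_training_steps=len(dataloader))
logger = logging.getLogger(__name__)

avg_loss = train(classifier_model, optimizer, scheduler, dataloader,
                 device, logger, logging_steps=1)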
def train_model(model: torch.nn.Module,
                train_dl: BatchSampler,
                dev_dl: BatchSampler,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler.LambdaLR,
                n_epochs: int) -> Tuple[Dict, Dict]:
    best_val, best_model_weights = {'val_f1': 0}, None
    for ep in range(n_epochs):
        for i, batch in enumerate(tqdm(train_dl, desc='Training')):
            model.train()
            optimizer.zero_grad()
            # Padding positions are masked out of the attention.
            loss, _ = model(batch[0],
                            attention_mask=batch[0] != tokenizer.pad_token_id,
                            labels=batch[1])
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Intermediate evaluation at a few fixed batch indices.
            if i in [600, 700, 100]:
                print(eval_model(model, dev_dl), flush=True)

        val_p, val_r, val_f1, val_loss = eval_model(model, dev_dl)
        current_val = {
            'val_f1': val_f1, 'val_p': val_p, 'val_r': val_r,
            'val_loss': val_loss, 'ep': ep
        }
        print(current_val, flush=True)

        if current_val['val_f1'] > best_val['val_f1']:
            best_val = current_val
            best_model_weights = model.state_dict()

    return best_model_weights, best_val
def train(
        config: fine_tune.config.BaseConfig,
        dataset: fine_tune.task.Dataset,
        model: fine_tune.model.Model,
        optimizer: torch.optim.AdamW,
        scheduler: torch.optim.lr_scheduler.LambdaLR,
        tokenizer: transformers.PreTrainedTokenizer,
):
    r"""Fine-tune or distill model on task specific dataset.

    Args:
        config:
            `fine_tune.config.BaseConfig` subclass whose attributes are used
            for experiment setup.
        dataset:
            Task specific dataset.
        model:
            Model which will be fine-tuned on `dataset`.
        optimizer:
            `torch.optim.AdamW` optimizer.
        scheduler:
            Linear warmup scheduler provided by the `transformers` package.
        tokenizer:
            Tokenizer paired with `model`.
    """
    # Training mode.
    model.train()

    # Model running device.
    device = config.device

    # Clean all gradients.
    optimizer.zero_grad()

    # Get experiment name and path.
    experiment_name = fine_tune.config.BaseConfig.experiment_name(
        experiment=config.experiment,
        model=config.model,
        task=config.task)
    experiment_dir = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT,
                                  experiment_name)

    # Create dataloader.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size // config.accum_step,
        collate_fn=dataset.create_collate_fn(max_seq_len=config.max_seq_len,
                                             tokenizer=tokenizer),
        shuffle=True)

    # Create tensorboard's `SummaryWriter`.
    writer = torch.utils.tensorboard.SummaryWriter(
        os.path.join(fine_tune.path.LOG, experiment_name))

    # Use cross-entropy as objective.
    objective = nn.CrossEntropyLoss()

    # Step and accumulation step counters.
    step = 0
    accum_step = 0
    total_accum_step = config.total_step * config.accum_step

    # Mini-batch loss and accumulated loss.
    # Updated once gradients accumulate to `config.batch_size` samples.
    loss = 0
    accum_loss = 0

    # `tqdm` CLI logger. We will manually update the progress bar.
    cli_logger = tqdm(desc=f'loss: {loss:.6f}', total=config.total_step)

    # Total update times: `config.total_step`.
    while accum_step < total_accum_step:
        # Mini-batch loop.
        for (input_ids, attention_mask, token_type_ids, label, _) in dataloader:
            # Accumulate cross-entropy loss, scaled by the number of
            # accumulation steps so the summed gradients match a full batch.
            accum_loss = objective(
                input=model(input_ids=input_ids.to(device),
                            token_type_ids=token_type_ids.to(device),
                            attention_mask=attention_mask.to(device)),
                target=label.to(device)) / config.accum_step

            # Mini-batch cross-entropy loss. Only used for logging.
            loss += accum_loss.item()

            # Backward pass on the accumulated loss.
            accum_loss.backward()

            # Increment accumulation step.
            accum_step += 1

            # Perform gradient descent once the effective mini-batch size is reached.
            if accum_step % config.accum_step == 0:
                # Gradient clipping.
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_norm)

                # Gradient descent.
                optimizer.step()

                # Update learning rate.
                scheduler.step()

                # Log on CLI.
                cli_logger.update()
                cli_logger.set_description(f'loss: {loss:.6f}')

                # Increment actual step.
                step += 1

                # Log loss and learning rate every `config.log_step` steps.
                if step % config.log_step == 0:
                    writer.add_scalar(f'{config.task}/{config.dataset}/loss',
                                      loss, step)
                    writer.add_scalar(
                        f'{config.task}/{config.dataset}/lr',
                        optimizer.state_dict()['param_groups'][0]['lr'],
                        step)

                # Clean up mini-batch loss.
                loss = 0

                # Clean up gradients.
                optimizer.zero_grad()

                # Save model every `config.ckpt_step` steps.
                if step % config.ckpt_step == 0:
                    torch.save(
                        model.state_dict(),
                        os.path.join(experiment_dir, f'model-{step}.pt'))

            # Stop training condition.
            if accum_step >= total_accum_step:
                break

    # Release IO resources.
    writer.flush()
    writer.close()
    cli_logger.close()

    # Save the latest model.
    torch.save(model.state_dict(),
               os.path.join(experiment_dir, f'model-{step}.pt'))
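# --- Pattern sketch (illustrative, placeholder arguments) ---
# The gradient-accumulation scheme used in the `train` function above, stripped
# of the experiment bookkeeping: each micro-batch loss is divided by
# `accum_step` so the summed gradients match a full batch of size
# micro_batch_size * accum_step, and the optimizer and scheduler only step
# every `accum_step` micro-batches.
import torch

def accumulate_and_step(model, dataloader, objective, optimizer, scheduler,
                        accum_step: int = 4, max_norm: float = 1.0):
    for i, (inputs, targets) in enumerate(dataloader, start=1):
        loss = objective(model(inputs), targets) / accum_step
        loss.backward()  # gradients accumulate across micro-batches
        if i % accum_step == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()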
def train_epoch(self, model: Reader, optimizer: torch.optim.Optimizer,
                scaler: GradScaler, train: DataLoader, val: DataLoader,
                scheduler: torch.optim.lr_scheduler.LambdaLR) -> float:
    """
    Performs one training epoch.

    :param model: The model you are training.
    :type model: Reader
    :param optimizer: Use this optimizer for training.
    :type optimizer: torch.optim.Optimizer
    :param scaler: Scaler for gradients when mixed precision is used.
    :type scaler: GradScaler
    :param train: The train dataset loader.
    :type train: DataLoader
    :param val: The validation dataset loader.
    :type val: DataLoader
    :param scheduler: Learning rate scheduler.
    :type scheduler: torch.optim.lr_scheduler.LambdaLR
    :return: Best exact match achieved among validations.
    :rtype: float
    """
    model.train()

    loss_sum = 0
    samples = 0
    startTime = time.time()
    total_tokens = 0

    optimizer.zero_grad()

    initStep = 0
    if self.resumeSkip is not None:
        initStep = self.resumeSkip
        self.resumeSkip = None

    iterator = tqdm(enumerate(train), total=len(train), initial=initStep)
    bestExactMatch = 0.0

    for current_it, batch in iterator:
        batch: ReaderBatch
        lastScale = scaler.get_scale()
        self.n_iter += 1

        batchOnDevice = batch.to(self.device)
        samples += 1
        try:
            with torch.cuda.amp.autocast(enabled=self.config["mixed_precision"]):
                startScores, endScores, jointScore, selectionScore = self._useModel(
                    model, batchOnDevice)

                # According to the config we can get the following loss combinations:
                #   joint components
                #   independent components
                #   joint components with HardEM
                #   independent components with HardEM

                logSpanProb = None
                if not self.config["independent_components_in_loss"]:
                    # Joint components in loss.
                    logSpanProb = Reader.scores2logSpanProb(
                        startScores, endScores, jointScore, selectionScore)

                # The user may want to use the hardEM loss with a certain probability.
                # In the original article this is not written clearly and it seems to be
                # the other way around; after consulting the authors the idea became clear.
                if self.config["hard_em_steps"] > 0 and \
                        random.random() <= min(self.update_it / self.config["hard_em_steps"],
                                               self.config["max_hard_em_prob"]):
                    # The loss is calculated only for the answer span with maximum probability.
                    if self.config["independent_components_in_loss"]:
                        loss = Reader.hardEMIndependentComponentsLoss(
                            startScores, endScores, jointScore, selectionScore,
                            batchOnDevice.answersMask)
                    else:
                        loss = Reader.hardEMLoss(logSpanProb, batchOnDevice.answersMask)
                else:
                    # The loss is calculated for all answer spans.
                    if self.config["independent_components_in_loss"]:
                        loss = Reader.marginalCompoundLossWithIndependentComponents(
                            startScores, endScores, jointScore, selectionScore,
                            batchOnDevice.answersMask)
                    else:
                        loss = Reader.marginalCompoundLoss(logSpanProb,
                                                           batchOnDevice.answersMask)

                if self.config["use_auxiliary_loss"] and batch.isGroundTruth:
                    # We must be sure that the user wants it and that the true
                    # passage is ground truth.
                    loss += Reader.auxiliarySelectedLoss(selectionScore)

            loss_sum += loss.item()
            scaler.scale(loss).backward()

        # Catch out-of-memory errors.
        except RuntimeError as e:
            if "CUDA out of memory." in str(e):
                torch.cuda.empty_cache()
                logging.error(e)
                tb = traceback.format_exc()
                logging.error(tb)
                continue
            else:
                raise e

        # Update parameters.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            filter(lambda p: p.requires_grad, model.parameters()),
            self.config["max_grad_norm"])
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        self.update_it += 1

        if math.isclose(lastScale, scaler.get_scale(), rel_tol=1e-6) and scheduler is not None:
            # We should not perform a scheduler step when the optimizer step was
            # omitted due to a change of the scale factor.
            scheduler.step()

        if self.update_it % self.config["validate_after_steps"] == 0:
            valLoss, exactMatch, passageMatch, samplesWithLoss = self.validate(model, val)
            logging.info(
                f"Steps:{self.update_it}, Training loss: {loss_sum / samples:.5f}, "
                f"Validation loss: {valLoss} (samples with loss {samplesWithLoss} "
                f"[{samplesWithLoss / len(val):.1%}]), Exact match: {exactMatch:.5f}, "
                f"Passage match: {passageMatch:.5f}")
            bestExactMatch = max(exactMatch, bestExactMatch)

            if self.update_it > self.config["first_save_after_updates_K"]:
                checkpoint = Checkpoint(
                    model.module if isinstance(model, DataParallel) else model,
                    optimizer, scheduler, train.sampler.actPerm, current_it + 1,
                    self.config, self.update_it)
                checkpoint.save(f"{self.config['save_dir']}/Reader_train"
                                f"_{get_timestamp()}"
                                f"_{socket.gethostname()}"
                                f"_{valLoss}"
                                f"_S_{self.update_it}"
                                f"_E_{current_it}.pt")

            model.train()

        # Statistics & logging.
        total_tokens += batch.inputSequences.numel()
        if (self.n_iter + 1) % 50 == 0 or current_it == len(iterator) - 1:
            iterator.set_description(
                f"Steps: {self.update_it} Tokens/s: {total_tokens / (time.time() - startTime)}, "
                f"Training loss: {loss_sum / samples}")

        if self.config["max_steps"] <= self.update_it:
            break

    logging.info(
        f"End of epoch training loss: {loss_sum / samples:.5f}, "
        f"best validation exact match: {bestExactMatch}")
    return bestExactMatch
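# --- Pattern sketch (illustrative, placeholder arguments) ---
# The mixed-precision update used in `train_epoch` above, reduced to its core:
# forward under autocast, backward through the GradScaler, unscale before
# clipping, and a scheduler step only when the scaler kept its scale (i.e. the
# optimizer step was not skipped because of inf/NaN gradients).
import math
import torch
from torch.cuda.amp import GradScaler, autocast

def amp_step(model, inputs, targets, objective, optimizer, scheduler,
             scaler: GradScaler, max_grad_norm: float = 1.0) -> float:
    last_scale = scaler.get_scale()
    with autocast():
        loss = objective(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if math.isclose(last_scale, scaler.get_scale(), rel_tol=1e-6):
        scheduler.step()
    return loss.item()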