def run_training_epoch(self): # get model model = self.get_model() # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() # model hooks if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() # track local dataloader so TPU can wrap each epoch train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device() train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device) # bookkeeping outputs = [] # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch"): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model.global_step = self.global_step # --------------- # RUN TRAIN STEP # --------------- _outputs = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs # detach tensors in batch_output before appending to outputs outputs.append(_recursive_detach(batch_output)) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 # update lr self.update_learning_rates(interval='step') # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 can_check_val = not self.disable_validation and can_check_epoch should_check_val = is_val_check_batch or early_stop_epoch should_check_val = should_check_val or ( is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and should_check_val # fast_dev_run always forces val checking after train batch if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) # when logs should be saved should_save_log = ( batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: if self.proc_rank == 0 and self.logger is not None: self.logger.save() # when metrics should be logged should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) # --------------- # CHECKPOINTING, EARLY STOPPING # --------------- # save checkpoint even when no test or val step are defined if self.fast_dev_run or should_check_val: self.call_checkpoint_callback() if self.enable_early_stop: self.early_stop_callback.check_metrics( self.callback_metrics) # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 self.total_batch_idx += 1 # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if early_stop_epoch or self.fast_dev_run: break # process epoch outputs if isinstance( model, (LightningDistributedDataParallel, LightningDataParallel)): model = model.module if self.is_overriden('training_epoch_end', model=model): epoch_output = model.training_epoch_end(outputs) _processed_outputs = self.process_output(epoch_output) log_epoch_metrics = _processed_outputs[2] callback_epoch_metrics = _processed_outputs[3] 
self.log_metrics(log_epoch_metrics, {})
        self.callback_metrics.update(callback_epoch_metrics)

    # in case validation step is missing and you are not running fast-dev to duplicate last batch
    if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
        self.call_checkpoint_callback()

        if self.enable_early_stop:
            self.early_stop_callback.check_metrics(self.callback_metrics)

    # Epoch end events
    with self.profiler.profile('on_epoch_end'):
        # callbacks
        self.on_epoch_end()
        # model hooks
        if self.is_function_implemented('on_epoch_end'):
            model.on_epoch_end()
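# `_with_is_last(...)` used in the epoch loop above is a small helper that is not
# shown here. A minimal sketch of what such a helper plausibly does — look one
# element ahead so the loop can flag the final batch; the real implementation may
# differ in detail:
def _with_is_last_sketch(iterable):
    it = iter(iterable)
    try:
        last = next(it)
    except StopIteration:
        return
    for value in it:
        yield last, False   # at least one more element follows
        last = value
    yield last, True        # nothing left: this was the last batch

# e.g. list(_with_is_last_sketch([10, 20])) == [(10, False), (20, True)]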
def train(self, model_path: Optional[str] = None): """ Main training entry point. Args: model_path (:obj:`str`, `optional`): Local path to the model if the model to train has been instantiated from a local path. If present, training will resume from the optimizer/scheduler states loaded here. """ train_dataloader = self.get_train_dataloader() if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = (self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1) else: t_total = int( len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs optimizer, scheduler = self.get_optimizers(num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if (model_path is not None and os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)) scheduler.load_state_dict( torch.load(os.path.join(model_path, "scheduler.pt"))) model = self.model if self.args.fp16: if not is_apex_available(): raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize( model, optimizer, opt_level=self.args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if self.args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, find_unused_parameters=True, ) if self.tb_writer is not None: self.tb_writer.add_text("args", self.args.to_json_string()) self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) # Train! if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size( ) else: total_train_batch_size = (self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)) logger.info("***** Running training *****") logger.info(" Num examples = %d", self.num_examples(train_dataloader)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", total_train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = self.global_step // (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = self.global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", self.global_step)
            logger.info(" Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            self.global_step = 0
            logger.info(" Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(num_train_epochs),
                            desc="Epoch",
                            disable=not self.is_local_master())
    for epoch in train_iterator:
        if isinstance(train_dataloader, DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(
                train_dataloader,
                [self.args.device]).per_device_loader(self.args.device)
            epoch_iterator = tqdm(parallel_loader,
                                  desc="Iteration",
                                  disable=not self.is_local_master())
        else:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=not self.is_local_master())

        # Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0: self._past = None for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self.training_step(model, inputs, optimizer) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): if self.args.fp16: torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), self.args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) if is_torch_tpu_available(): xm.optimizer_step(optimizer) else: optimizer.step() scheduler.step() model.zero_grad() self.global_step += 1 self.epoch = epoch + (step + 1) / len(epoch_iterator) if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (self.global_step == 1 and self.args.logging_first_step): logs: Dict[str, float] = {} logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps # backward compatibility for pytorch schedulers logs["learning_rate"] = ( scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else scheduler.get_lr()[0]) logging_loss = tr_loss self.log(logs) if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: self.evaluate() if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. if hasattr(model, "module"): assert model.module is self.model else: assert model is self.model # Save model checkpoint output_dir = os.path.join( self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") self.save_model(output_dir) if self.is_world_master(): self._rotate_checkpoints() if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) elif self.is_world_master(): torch.save( optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save( scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) if self.args.max_steps > 0 and self.global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and self.global_step > self.args.max_steps: train_iterator.close() break if self.args.tpu_metrics_debug or self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) if self.tb_writer: self.tb_writer.close() if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info( "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" ) return TrainOutput(self.global_step, tr_loss / self.global_step)
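# The loop above only clips, steps and zeroes gradients once every
# `gradient_accumulation_steps` batches. A self-contained CPU sketch of that
# accumulate -> clip -> step -> zero pattern, using a toy model; the names
# `accumulation_steps` and `max_grad_norm` stand in for the corresponding
# TrainingArguments fields and are not taken from the code above:
import torch

def accumulation_demo(accumulation_steps=4, max_grad_norm=1.0, num_batches=8):
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for step in range(num_batches):
        x, y = torch.randn(3, 4), torch.randint(0, 2, (3,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (loss / accumulation_steps).backward()  # gradients add up in .grad across micro-batches
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()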
def train_loop(folds, fold): if CFG.device == 'GPU': LOGGER.info(f"========== fold: {fold} training ==========") elif CFG.device == 'TPU': if CFG.nprocs == 1: LOGGER.info(f"========== fold: {fold} training ==========") elif CFG.nprocs == 8: xm.master_print(f"========== fold: {fold} training ==========") # ==================================================== # loader # ==================================================== trn_idx = folds[folds['fold'] != fold].index val_idx = folds[folds['fold'] == fold].index train_folds = folds.loc[trn_idx].reset_index(drop=True) valid_folds = folds.loc[val_idx].reset_index(drop=True) valid_labels = valid_folds[CFG.target_cols].values train_dataset = TrainDataset(train_folds, transform=get_transforms(data='train')) valid_dataset = TrainDataset(valid_folds, transform=get_transforms(data='valid')) if CFG.device == 'GPU': train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=True, drop_last=True) valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size * 2, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False) elif CFG.device == 'TPU': train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=CFG.batch_size, sampler=train_sampler, drop_last=True, num_workers=CFG.num_workers) valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False) valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=CFG.batch_size * 2, sampler=valid_sampler, drop_last=False, num_workers=CFG.num_workers) # ==================================================== # scheduler # ==================================================== def get_scheduler(optimizer): if CFG.scheduler == 'ReduceLROnPlateau': scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps) elif CFG.scheduler == 'CosineAnnealingLR': scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1) elif CFG.scheduler == 'CosineAnnealingWarmRestarts': scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1) return scheduler # ==================================================== # model & optimizer # ==================================================== if CFG.device == 'TPU': device = xm.xla_device() elif CFG.device == 'GPU': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = CustomResNet200D_WLF(CFG.model_name, pretrained=False) model.load_state_dict( torch.load(CFG.student, map_location=torch.device('cpu'))['model']) model.to(device) optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False) scheduler = get_scheduler(optimizer) # ==================================================== # loop # ==================================================== criterion = nn.BCEWithLogitsLoss() best_score = 0. 
best_loss = np.inf for epoch in range(CFG.epochs): start_time = time.time() # train if CFG.device == 'TPU': if CFG.nprocs == 1: avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device) elif CFG.nprocs == 8: para_train_loader = pl.ParallelLoader(train_loader, [device]) avg_loss = train_fn( para_train_loader.per_device_loader(device), model, criterion, optimizer, epoch, scheduler, device) elif CFG.device == 'GPU': avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device) # eval if CFG.device == 'TPU': if CFG.nprocs == 1: avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device) elif CFG.nprocs == 8: para_valid_loader = pl.ParallelLoader(valid_loader, [device]) avg_val_loss, preds, valid_labels = valid_fn( para_valid_loader.per_device_loader(device), model, criterion, device) preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy() valid_labels = idist.all_gather( torch.tensor(valid_labels)).to('cpu').numpy() elif CFG.device == 'GPU': avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device) if isinstance(scheduler, ReduceLROnPlateau): scheduler.step(avg_val_loss) elif isinstance(scheduler, CosineAnnealingLR): scheduler.step() elif isinstance(scheduler, CosineAnnealingWarmRestarts): scheduler.step() # scoring score, scores = get_score(valid_labels, preds) elapsed = time.time() - start_time if CFG.device == 'GPU': LOGGER.info( f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s' ) LOGGER.info( f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}' ) elif CFG.device == 'TPU': if CFG.nprocs == 1: LOGGER.info( f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s' ) LOGGER.info( f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}' ) elif CFG.nprocs == 8: xm.master_print( f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s' ) xm.master_print( f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}' ) if score > best_score: best_score = score if CFG.device == 'GPU': LOGGER.info( f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model' ) torch.save({ 'model': model.state_dict(), 'preds': preds }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth') elif CFG.device == 'TPU': if CFG.nprocs == 1: LOGGER.info( f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model' ) elif CFG.nprocs == 8: xm.master_print( f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model' ) xm.save({ 'model': model.state_dict(), 'preds': preds }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth') if avg_val_loss < best_loss: best_loss = avg_val_loss if CFG.device == 'GPU': LOGGER.info( f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model') torch.save({ 'model': model.state_dict(), 'preds': preds }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth') elif CFG.device == 'TPU': if CFG.nprocs == 1: LOGGER.info( f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model' ) elif CFG.nprocs == 8: xm.master_print( f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model' ) xm.save({ 'model': model.state_dict(), 'preds': preds }, OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth') if CFG.nprocs != 8: check_point = torch.load( OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth') for c in [f'pred_{c}' for c in CFG.target_cols]: valid_folds[c] = np.nan 
valid_folds[[f'pred_{c}' for c in CFG.target_cols]] = check_point['preds']

    return valid_folds
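# In the 8-core TPU branch above, per-core predictions and labels are gathered
# with ignite's idist.all_gather. A hedged alternative sketch using only torch_xla:
# xm.mesh_reduce hands the reduce function a list with one entry per core. The
# tag string and function name below are illustrative, not from the notebook.
import torch
import torch_xla.core.xla_model as xm

def gather_across_cores(array, tag='valid_preds'):
    # must be called on every participating core
    gathered = xm.mesh_reduce(tag, torch.as_tensor(array), torch.cat)
    return gathered.cpu().numpy()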
def run_training_epoch(self): # get model model = self.get_model() # Epoch start events with self.profiler.profile('on_epoch_start'): # callbacks self.on_epoch_start() # model hooks if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() # track local dataloader so TPU can wrap each epoch train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device(self.tpu_id) train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device) # bookkeeping outputs = [] # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch"): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model.global_step = self.global_step # --------------- # RUN TRAIN STEP # --------------- _outputs = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory if self.is_overridden('training_epoch_end', model=self.get_model()): outputs.append(batch_output) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 # TODO: consolidate all actions that need to take place only after # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # update lr self.update_learning_rates(interval='step') # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 can_check_val = not self.disable_validation and can_check_epoch should_check_val = is_val_check_batch or early_stop_epoch should_check_val = should_check_val or ( is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and should_check_val # --------------- # CHECKPOINTING, EARLY STOPPING # --------------- # fast_dev_run always forces val checking after train batch if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) self.call_checkpoint_callback() # when logs should be saved should_save_log = ( batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: if self.proc_rank == 0 and self.logger is not None: self.logger.save() # when metrics should be logged should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 self.total_batch_idx += 1 # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if early_stop_epoch or self.fast_dev_run: break if self.use_horovod: hvd.join(hvd.local_rank() if self.on_gpu else -1) # process epoch outputs model = self.get_model() if self.is_overridden('training_epoch_end', model=model): epoch_output = model.training_epoch_end(outputs) _processed_outputs = 
self.process_output(epoch_output)
        log_epoch_metrics = _processed_outputs[2]
        callback_epoch_metrics = _processed_outputs[3]

        self.log_metrics(log_epoch_metrics, {})
        self.callback_metrics.update(callback_epoch_metrics)
        self.add_progress_bar_metrics(_processed_outputs[1])

    # when no val loop is present or fast-dev-run still need to call checkpoints
    if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
        self.call_checkpoint_callback()

    # Epoch end events
    with self.profiler.profile('on_epoch_end'):
        # callbacks
        self.on_epoch_end()
        # model hooks
        if self.is_function_implemented('on_epoch_end'):
            model.on_epoch_end()
def process_dataloader(self, dataloader):
    device = xm.xla_device(self.trainer.tpu_id)
    dataloader = xla_pl.ParallelLoader(dataloader, [device])
    dataloader = dataloader.per_device_loader(device)
    return dataloader
def main(config_file='config/bert_config.json'): """Main method for training. Args: config_file: in config dir """ global datasets # 0. Load config and mkdir with open(config_file) as fin: config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d)) get_path(os.path.join(config.model_path, config.experiment_name)) get_path(config.log_path) if config.model_type in ['rnn', 'lr', 'cnn']: # build vocab for rnn build_vocab(file_in=config.all_train_file_path, file_out=os.path.join(config.model_path, 'vocab.txt')) # 1. Load data data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'), max_seq_len=config.max_seq_len, model_type=config.model_type, config=config) def load_dataset(): datasets = data.load_train_and_valid_files( train_file=config.train_file_path, valid_file=config.valid_file_path) return datasets if config.serial_load: datasets = SERIAL_EXEC.run(load_dataset) else: datasets = load_dataset() train_set, valid_set_train, valid_set_valid = datasets if torch.cuda.is_available(): device = torch.device('cuda') # device = torch.device('cpu') # torch.distributed.init_process_group(backend="nccl") # sampler_train = DistributedSampler(train_set) sampler_train = RandomSampler(train_set) else: device = torch.device('cpu') sampler_train = RandomSampler(train_set) # TPU device = xm.xla_device() sampler_train = torch.utils.data.distributed.DistributedSampler( train_set, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) data_loader = { 'train': DataLoader(train_set, sampler=sampler_train, batch_size=config.batch_size), 'valid_train': DataLoader(valid_set_train, batch_size=config.batch_size, shuffle=False), 'valid_valid': DataLoader(valid_set_valid, batch_size=config.batch_size, shuffle=False) } # 2. Build model # model = MODEL_MAP[config.model_type](config) model = WRAPPED_MODEL #load model states. # if config.trained_weight: # model.load_state_dict(torch.load(config.trained_weight)) model.to(device) if torch.cuda.is_available(): model = model # model = torch.nn.parallel.DistributedDataParallel( # model, find_unused_parameters=True) # 3. 
Train trainer = Trainer(model=model, data_loader=data_loader, device=device, config=config) # best_model_state_dict = trainer.train() if config.model_type == 'bert': no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [{ 'params': [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], 'weight_decay_rate': 0.01 }, { 'params': [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], 'weight_decay_rate': 0.0 }] optimizer = AdamW(optimizer_parameters, lr=config.lr, betas=(0.9, 0.999), weight_decay=1e-8, correct_bias=False) else: # rnn optimizer = Adam(model.parameters(), lr=config.lr) # if config.model_type == 'bert': # scheduler = get_linear_schedule_with_warmup( # optimizer, # num_warmup_steps=config.num_warmup_steps, # num_training_steps=config.num_training_steps) # else: # rnn # scheduler = get_constant_schedule(optimizer) criterion = nn.CrossEntropyLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, batch in enumerate(loader): # batch = tuple(t.to(self.device) for t in batch) output = model(*batch[:-1]) # the last one is label loss = criterion(output, batch[-1]) loss.backward() # xm.optimizer_step(optimizer) # optimizer.zero_grad() tracker.add(FLAGS.batch_size) if (x + 1) % config.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) # after 梯度累加的基本思想在于,在优化器更新参数前,也就是执行 optimizer.step() 前,进行多次反向传播,是的梯度累计值自动保存在 parameter.grad 中,最后使用累加的梯度进行参数更新。 xm.optimizer_step(optimizer) optimizer.zero_grad() if xm.get_ordinal() == 0: if x % FLAGS.log_steps == 0: print( '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}' .format(xm.get_ordinal(), x, loss.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() data, pred, target = None, None, None tracker = xm.RateTracker() for x, batch in enumerate(loader): output = model(*batch[:-1]) # the last one is label target = batch[-1] # pred = output.max(1, keepdim=True)[1] # correct += pred.eq(target.view_as(pred)).sum().item() for i in range(len(output)): logits = output[i] pred = int(torch.argmax(logits, dim=-1)) if pred == target[i]: correct += 1 total_samples += len(output) if xm.get_ordinal() == 0: if x % FLAGS.log_steps == 0: print( '[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}' .format(xm.get_ordinal(), x, correct * 1.0 / total_samples, tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) accuracy = 100.0 * correct / total_samples if xm.get_ordinal() == 0: print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True) return accuracy, data, pred, target # Train and eval loops accuracy = 0.0 data, pred, target = None, None, None for epoch in range(FLAGS.num_epoch): para_loader = pl.ParallelLoader(data_loader['train'], [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) # para_loader = pl.ParallelLoader(data_loader['valid_train'], [device]) # accuracy_train, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) para_loader = pl.ParallelLoader(data_loader['valid_valid'], [device]) accuracy_valid, data, pred, target = test_loop_fn( para_loader.per_device_loader(device)) xm.master_print("Finished test epoch {}, valid={:.2f}".format( epoch, accuracy_valid)) if FLAGS.metrics_debug: xm.master_print(met.metrics_report()) # 4. 
Save model
    # if xm.get_ordinal() == 0:
    #     # if epoch == FLAGS.num_epoch - 1:
    #     #     WRAPPED_MODEL.to('cpu')
    #     torch.save(WRAPPED_MODEL.state_dict(), os.path.join(
    #         config.model_path, config.experiment_name,
    #         config.model_type + '-' + str(epoch + 1) + '.bin'))
    #     xm.master_print('saved model.')
    #     WRAPPED_MODEL.to(device)

    return accuracy_valid
def train_mnist(): torch.manual_seed(1) if FLAGS.fake_data: train_loader = xu.SampleGenerator( data=(torch.zeros(FLAGS.batch_size, 1, 28, 28), torch.zeros(FLAGS.batch_size, dtype=torch.int64)), sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size()) test_loader = xu.SampleGenerator( data=(torch.zeros(FLAGS.batch_size, 1, 28, 28), torch.zeros(FLAGS.batch_size, dtype=torch.int64)), sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size()) else: train_dataset = datasets.MNIST(FLAGS.datadir, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])) test_dataset = datasets.MNIST(FLAGS.datadir, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ])) train_sampler = None if xm.xrt_world_size() > 1: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=FLAGS.batch_size, sampler=train_sampler, shuffle=False if train_sampler else True, num_workers=FLAGS.num_workers) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=FLAGS.batch_size, shuffle=False, num_workers=FLAGS.num_workers) # Scale learning rate to num cores lr = FLAGS.lr * xm.xrt_world_size() device = xm.xla_device() model = MNIST().to(device) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum) loss_fn = nn.NLLLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, target) in loader: optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer) tracker.add(FLAGS.batch_size) if x % FLAGS.log_steps == 0: test_utils.print_training_update(device, x, loss.item(), tracker.rate(), tracker.global_rate()) def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() for x, (data, target) in loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum().item() total_samples += data.size()[0] accuracy = 100.0 * correct / total_samples test_utils.print_test_update(device, accuracy) return accuracy accuracy = 0.0 for epoch in range(1, FLAGS.num_epochs + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) para_loader = pl.ParallelLoader(test_loader, [device]) accuracy = test_loop_fn(para_loader.per_device_loader(device)) if FLAGS.metrics_debug: print(met.metrics_report()) return accuracy
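# train_mnist() above runs on a single XLA device; it is normally launched once
# per TPU core. A minimal sketch of the usual entry point, assuming the standard
# torch_xla multiprocessing launcher and that the script's FLAGS carries a
# `num_cores` field (an assumption, not shown above):
import torch
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    train_mnist()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores, start_method='fork')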
def train(train_loader, model, optimizer, scheduler, epoch, args, DEVICE):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model = model.train().to(DEVICE)
    loader = pl.ParallelLoader(train_loader, [DEVICE]).per_device_loader(DEVICE)

    # noise2net = Res2Net(epsilon=0.50, hidden_planes=16, batch_size=args.batch_size).train().to(DEVICE)

    end = time.time()
    for i, (images, target) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        bx = images
        by = target

        print("Zero grad")
        optimizer.zero_grad()

        # with torch.no_grad():
        #     if random.random() < 0.5:
        #         batch_size = bx.shape[0]
        #         noise2net.reload_parameters()
        #         noise2net.set_epsilon(random.uniform(args.noisenet_max_eps / 2.0, args.noisenet_max_eps))
        #         bx = bx.reshape((1, batch_size * 3, 224, 224))
        #         bx = noise2net(bx)
        #         bx = bx.reshape((batch_size, 3, 224, 224))

        print("Forward")
        logits = model(bx)

        print("Cross Entropy")
        loss = F.cross_entropy(logits, by)

        # measure accuracy and record loss
        output, target = logits, by
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        print("Backward")
        loss.backward()

        print("Step")
        xm.optimizer_step(optimizer)

        print("Scheduler step")
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and xm.is_master_ordinal():
            progress.display(i)
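# `accuracy(output, target, topk=(1, 5))` is called above but not defined in this
# snippet. A minimal sketch of the conventional top-k accuracy helper it
# presumably corresponds to (the project's own version may differ):
import torch

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # (B, maxk) indices of the top classes
        pred = pred.t()                             # (maxk, B)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res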
def main(config_file='config/bert_config.json'): """Main method for training. Args: config_file: in config dir """ # 0. Load config and mkdir with open(config_file) as fin: config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d)) get_path(os.path.join(config.model_path, config.experiment_name)) get_path(config.log_path) get_path(config.prediction_path) # get_path(config.checkpoint_path) if config.model_type == 'rnn': # build vocab for rnn build_vocab(file_in=config.all_train_file_path, file_out=os.path.join(config.model_path, 'vocab.txt')) # 1. Load data data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'), max_seq_len=config.max_seq_len, model_type=config.model_type, config=config) def load_dataset(): datasets = data.load_train_and_valid_files( train_file=config.train_file_path, valid_file=config.valid_file_path) return datasets if config.serial_load: datasets = SERIAL_EXEC.run(load_dataset) else: datasets = load_dataset() train_set, valid_set_train, valid_set_valid, train_exam, valid_exam, train_feat, valid_feat = datasets if torch.cuda.is_available(): device = torch.device('cuda') # device = torch.device('cpu') # torch.distributed.init_process_group(backend="nccl") # sampler_train = DistributedSampler(train_set) sampler_train = RandomSampler(train_set) else: device = torch.device('cpu') sampler_train = RandomSampler(train_set) # TPU device = xm.xla_device() sampler_train = torch.utils.data.distributed.DistributedSampler( train_set, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) data_loader = { 'train': DataLoader(train_set, sampler=sampler_train, batch_size=config.batch_size), 'valid_train': DataLoader(valid_set_train, batch_size=config.batch_size, shuffle=False), 'valid_valid': DataLoader(valid_set_valid, batch_size=config.batch_size, shuffle=False), 'train_exam': train_exam, 'valid_exam': valid_exam, 'train_feat': train_feat, 'valid_feat': valid_feat } # 2. Build model # TPU device = xm.xla_device() model = WRAPPED_MODEL model.to(device) if config.model_type == 'bert': no_decay = ['bias', 'gamma', 'beta'] optimizer_parameters = [{ 'params': [ p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) ], 'weight_decay_rate': 0.01 }, { 'params': [ p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) ], 'weight_decay_rate': 0.0 }] optimizer = AdamW(optimizer_parameters, lr=config.lr, betas=(0.9, 0.999), weight_decay=1e-8, correct_bias=False) else: # rnn optimizer = Adam(model.parameters(), lr=config.lr) criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=IGNORE_INDEX) # 交叉熵损失 binary_criterion = nn.BCEWithLogitsLoss(reduction='mean') # 二元损失 sp_loss_fct = nn.BCEWithLogitsLoss(reduction='none') # 用于sp,平均值自己算 #load model states. # if config.trained_weight: # model.load_state_dict(torch.load(config.trained_weight)) # model.to(device) # if torch.cuda.is_available(): # model = model # model = torch.nn.parallel.DistributedDataParallel( # model, find_unused_parameters=True) # 3. 
Train trainer = Trainer(model=model, data_loader=data_loader, device=device, config=config) def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, batch in enumerate(loader): batch = tuple(t.to(device) for t in batch) # loss = self.criterion(logits, batch[-1]) start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model( *batch) loss1 = criterion(start_logits, batch[6]) + criterion( end_logits, batch[7]) # y1, y2 loss2 = config.type_lambda * criterion(type_logits, batch[8]) # q_type # sent_num_in_batch = batch[9].sum() # is_support # sent_num_in_batch = 1.0 + sent_num_in_batch # to avoid devide by zero # loss3 = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum() * self.config.sp_lambda / sent_num_in_batch loss = loss1 + loss2 loss.backward() del batch #try to save cpu mem. tracker.add(FLAGS.batch_size) if (x + 1) % config.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) # after 梯度累加的基本思想在于,在优化器更新参数前,也就是执行 optimizer.step() 前,进行多次反向传播,是的梯度累计值自动保存在 parameter.grad 中,最后使用累加的梯度进行参数更新。 xm.optimizer_step(optimizer) optimizer.zero_grad() if xm.get_ordinal() == 0: if x % FLAGS.log_steps == 0: print( '[xla:{}]({}) Loss1={:.5f} Loss2={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}' .format(xm.get_ordinal(), x, loss1.item(), loss2.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) # def test_loop_fn(loader): # total_samples = 0 # correct = 0 # model.eval() # data, pred, target = None, None, None # tracker = xm.RateTracker() # for x, batch in enumerate(loader): # start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(*batch) # loss1 = self.criterion(start_logits, batch[6]) + self.criterion(end_logits, batch[7])#y1,y2 # loss2 = self.config.type_lambda * self.criterion(type_logits, batch[8])#q_type # sent_num_in_batch = batch[9].sum() # start_mapping # sp_value = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum() # if sent_num_in_batch != 0: # loss3 = self.config.sp_lambda * sp_value / sent_num_in_batch # else: # loss3 = self.config.sp_lambda * sp_value * 1e30 # # loss = loss1 + loss2 + loss3 # loss_list = [loss, loss1, loss2, loss3] # # for i, l in enumerate(loss_list): # if not isinstance(l, int): # total_test_loss[i] += l.item() # # batchsize = batch[0].size(0) # # ids # answer_dict_ = convert_to_tokens(exam, feats, batch[5], start_position.data.cpu().numpy().tolist(), # end_position.data.cpu().numpy().tolist(), # np.argmax(type_logits.data.cpu().numpy(), 1)) # answer_dict.update(answer_dict_) # # predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy() # for i in range(predict_support_np.shape[0]): # cur_sp_pred = [] # cur_id = batch[5][i].item() # # cur_sp_logit_pred = [] # for sp logit output # for j in range(predict_support_np.shape[1]): # if j >= len(exam[cur_id].sent_names): # break # if need_sp_logit_file: # temp_title, temp_id = exam[cur_id].sent_names[j] # cur_sp_logit_pred.append((temp_title, temp_id, predict_support_np[i, j])) # if predict_support_np[i, j] > self.config.sp_threshold: # cur_sp_pred.append(exam[cur_id].sent_names[j]) # sp_dict.update({cur_id: cur_sp_pred}) # # new_answer_dict = {} # for key, value in answer_dict.items(): # new_answer_dict[key] = value.replace(" ", "") # prediction = {'answer': new_answer_dict, 'sp': sp_dict} # with open(prediction_file, 'w', encoding='utf8') as f: # json.dump(prediction, f, indent=4, ensure_ascii=False) # # for i, l in 
enumerate(total_test_loss): # print("Test Loss{}: {}".format(i, l / len(dataloader))) # # if xm.get_ordinal() == 0: # if x % FLAGS.log_steps == 0: # print('[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format( # xm.get_ordinal(), x, correct*1.0/total_samples, tracker.rate(), # tracker.global_rate(), time.asctime()), flush=True) # # accuracy = 100.0 * correct / total_samples # if xm.get_ordinal() == 0: # print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True) # return accuracy, data, pred, target # Train and eval loops accuracy = 0.0 data, pred, target = None, None, None for epoch in range(FLAGS.num_epoch): para_loader = pl.ParallelLoader(data_loader['train'], [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) # para_loader = pl.ParallelLoader(data_loader['valid_train'], [device]) # accuracy_train, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) # para_loader = pl.ParallelLoader(data_loader['valid_valid'], [device]) # accuracy_valid, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) # xm.master_print("Finished test epoch {}, valid={:.2f}".format(epoch, accuracy_valid)) # # 4. Save model # torch.save(best_model_state_dict, # os.path.join(config.model_path, 'model.bin')) if xm.get_ordinal() == 0: results = trainer.valid() xm.master_print("Finished training epoch {}".format(results))
def __init__(self, mode, opts): print('Initializing model and data source...', end='', file=stderr) stderr.flush() self.learning_rates = dict( zip(opts.learning_rate_steps, opts.learning_rate_rates)) self.opts = opts if mode == 'new': torch.manual_seed(opts.random_seed) # Initialize data dataset = data.Slice(opts) dataset.load_data(opts.dat_file) opts.training = True if opts.global_model == 'autoencoder': model = ae.AutoEncoder(opts, dataset) elif opts.global_model == 'mfcc_inverter': model = mi.MfccInverter(opts, dataset) model.post_init(dataset) dataset.post_init(model) optim = torch.optim.Adam(params=model.parameters(), lr=self.learning_rates[0]) self.state = checkpoint.State(0, model, dataset, optim) self.start_step = self.state.step else: self.state = checkpoint.State() self.state.load(opts.ckpt_file, opts.dat_file) self.start_step = self.state.step # print('Restored model, data, and optim from {}'.format(opts.ckpt_file), file=stderr) #print('Data state: {}'.format(state.data), file=stderr) #print('Model state: {}'.format(state.model.checksum())) #print('Optim state: {}'.format(state.optim_checksum())) stderr.flush() if self.state.model.bn_type == 'vae': self.anneal_schedule = dict( zip(opts.bn_anneal_weight_steps, opts.bn_anneal_weight_vals)) self.ckpt_path = util.CheckpointPath(self.opts.ckpt_template) self.quant = None self.target = None self.softmax = torch.nn.Softmax(1) # input to this is (B, Q, N) if self.opts.hwtype == 'GPU': self.device = torch.device('cuda') self.data_loader = self.state.data_loader self.data_loader.set_target_device(self.device) self.optim_step_fn = (lambda: self.state.optim.step(self.loss_fn)) self.data_iter = GPULoaderIter(iter(self.data_loader)) else: import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl self.device = xm.xla_device() self.data_loader = pl.ParallelLoader(self.state.data_loader, [self.device]) self.data_iter = TPULoaderIter(self.data_loader, self.device) self.optim_step_fn = (lambda: xm.optimizer_step( self.state.optim, optimizer_args={'closure': self.loss_fn})) self.state.init_torch_generator() print('Done.', file=stderr) stderr.flush()
def train(self, model_path=None, dev_objective=None): """ Main training entry point. The training logic is directly borrowed from transformers.Trainer (version 3.0.2). Add early stopping. """ self.best_dir = None self.objective = -float("inf") self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective # Data loading. train_dataloader = self.get_train_dataloader() num_update_steps_per_epoch = len( train_dataloader) // self.args.gradient_accumulation_steps if num_update_steps_per_epoch == 0: num_update_steps_per_epoch = 1 if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( self.args.max_steps % num_update_steps_per_epoch > 0) else: t_total = int( len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs self.create_optimizer_and_scheduler(num_training_steps=t_total) optimizer = self.optimizer scheduler = self.lr_scheduler # Check if saved optimizer or scheduler states exist if (model_path is not None and os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt"))): # Load in optimizer and scheduler states optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)) scheduler.load_state_dict( torch.load(os.path.join(model_path, "scheduler.pt"))) model = self.model if self.args.fp16 and _use_apex: if not transformers.is_apex_available(): raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." ) model, optimizer = amp.initialize( model, optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if self.args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, find_unused_parameters=True, ) # Train if transformers.is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size( ) else: total_train_batch_size = (self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)) logger.info("***** Running training *****") logger.info(" Num examples = %d", self.num_examples(train_dataloader)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) self.global_step = 0 self.epoch = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_path is not None: # set global_step to global_step of last saved checkpoint from model path try: self.global_step = int(model_path.split("-")[-1].split("/")[0]) epochs_trained = self.global_step // ( len(train_dataloader) // self.args.gradient_accumulation_steps) steps_trained_in_current_epoch = self.global_step % ( len(train_dataloader) // self.args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", self.global_step) logger.info( " Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: self.global_step = 0 logger.info(" Starting fine-tuning.") tr_loss = torch.tensor(0.0).to(self.args.device) logging_loss_scalar = 0.0 model.zero_grad() train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()) for epoch in train_iterator: if isinstance(train_dataloader, DataLoader) and isinstance( train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) if transformers.is_torch_tpu_available(): parallel_loader = pl.ParallelLoader( train_dataloader, [self.args.device]).per_device_loader(self.args.device) epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master()) else: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) # Reset the past mems state at the beginning of each epoch if necessary. 
if self.args.past_index >= 0: self._past = None for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self.training_step(model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): if self.args.fp16 and _use_native_amp: self.scaler.unscale_(optimizer) norm = torch.nn.utils.clip_grad_norm_( model.parameters(), self.args.max_grad_norm) elif self.args.fp16: norm = torch.nn.utils.clip_grad_norm_( amp.master_params(optimizer), self.args.max_grad_norm) else: norm = torch.nn.utils.clip_grad_norm_( model.parameters(), self.args.max_grad_norm) if transformers.is_torch_tpu_available(): xm.optimizer_step(optimizer) elif self.args.fp16 and _use_native_amp: self.scaler.step(optimizer) self.scaler.update() else: optimizer.step() scheduler.step() model.zero_grad() self.global_step += 1 self.epoch = epoch + (step + 1) / len(epoch_iterator) if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (self.global_step == 1 and self.args.logging_first_step): logs = {} tr_loss_scalar = tr_loss.item() logs["loss"] = (tr_loss_scalar - logging_loss_scalar ) / self.args.logging_steps logs["norm"] = norm.item() # backward compatibility for pytorch schedulers logs["learning_rate"] = ( scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else scheduler.get_lr()[0]) logging_loss_scalar = tr_loss_scalar self.log(logs) # ---------------------------------------------------------------------- # BEGIN CHANGES. # ---------------------------------------------------------------------- metrics = None if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0: output = self.evaluate() metrics = output.metrics objective = self.dev_objective(metrics) if objective > self.objective: logger.info( "Best dev result: {}".format(objective)) self.objective = objective self.save_model(self.args.output_dir) # ---------------------------------------------------------------------- # END CHANGES. # ---------------------------------------------------------------------- if self.args.max_steps > 0 and self.global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and self.global_step > self.args.max_steps: train_iterator.close() break if self.args.tpu_metrics_debug or self.args.debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info( "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" ) return TrainOutput(self.global_step, tr_loss / self.global_step), self.objective
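# `default_dev_objective` is referenced in the signature above but not shown. A
# hedged, purely illustrative sketch of what such a callable might look like:
# map the evaluation metrics dict to one scalar that the trainer maximizes. The
# metric key names below are assumptions, not the project's actual keys.
def default_dev_objective(metrics):
    for key in ("eval_f1", "eval_acc", "eval_accuracy"):
        if key in metrics:
            return metrics[key]
    # fall back to negated loss so that "larger is better" still holds
    return -metrics.get("eval_loss", float("inf"))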
def train(rank, args): print('enter train @ %s'%(rank), flush=True) args.rank = rank args.split = '' torch.manual_seed(42) save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt') tokenizer = get_tokenizer(args) args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size train_dataset = get_dataset(args) batched_already = hasattr(train_dataset, '__getbatch__') if args.total_num_updates < 100: args.total_num_updates = len(train_dataset) * args.total_num_updates if args.warmup_updates < 1: args.warmup_updates = int(args.total_num_updates * args.warmup_updates) else: args.warmup_updates = int(args.warmup_updates) train_sampler = None if args.gpus: dist.init_process_group( 'nccl', rank=rank, world_size=args.world_size ) if args.gpus > 1: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=args.gpus, rank=rank, shuffle=args.shuffle) else: rank = xm.get_ordinal() if xm.xrt_world_size() > 1: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=args.shuffle) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size if not batched_already else None, sampler=train_sampler, pin_memory=True, shuffle=False, num_workers=args.num_workers) eval_loaders = [] if args.eval_dir: for split in args.splits.split(','): split = split.strip() eval_sampler = None if args.gpus: if args.gpus > 1: eval_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=args.gpus, rank=rank, shuffle=False) else: rank = xm.get_ordinal() if xm.xrt_world_size() > 1: eval_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False) args.split = split eval_dataset = get_eval_dataset(args) eval_loader = torch.utils.data.DataLoader( eval_dataset, batch_size=args.batch_size if not batched_already else None, sampler=eval_sampler, pin_memory=True, shuffle=False, num_workers=args.num_workers) eval_loaders.append(eval_loader) if args.gpus: assert apex_enabled torch.cuda.set_device(rank) ########################## ## ## Model Creation ## ########################## model = get_model(args, tokenizer) model.cuda(rank) device = torch.device('cuda:'+str(rank)) ########################## ## ## Init Optimizer ## ########################## optimizer = apex.optimizers.FusedAdam( model_get_parameters(model, lr=args.lr, lw_lr_decay=args.lw_lr_decay, weight_decay=args.weight_decay, special_layer_wise_lr=args.special_layer_wise_lr, log = rank == 0, ), # use this function to set extra optimizer arguments, # see model_get_parameters betas=(0.9, 0.999), eps=1e-6, lr=args.lr, weight_decay=args.weight_decay ) model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model = DDP(model) batches = train_loader else: assert tpu_enabled device = xm.xla_device() ########################## ## ## Model Creation ## ########################## model = get_model(args, tokenizer) ########################## ## ## For shared parameters, TPU requires modules to be tied after .to(device) ## So we first find the shared parameters first ## ########################## shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)} model.to(device) do_share_parameters_again(model, shared_parameters, log = rank == 0) ########################## ## ## Init Optimizer ## ########################## optimizer = optim.Adam( model_get_parameters(model, lr=args.lr, 
lw_lr_decay=args.lw_lr_decay, weight_decay=args.weight_decay ), # use this function to set extra optimizer arguments, # see model_get_parameters lr=args.lr, weight_decay=args.weight_decay ) writer = None if xm.is_master_ordinal(): writer = test_utils.get_summary_writer(args.save_dir) xm.rendezvous("load_checkpoint") # wait for all workers xm.mark_step() # tracker = xm.RateTracker() if args.restore_file: states = torch.load(args.restore_file, map_location=device) for k, v in list(states.items()): if k.startswith('module.'): del states[k] k = k[7:] states[k] = v if k.endswith('position_ids'): del states[k] states[k[:-12] + 'position_embeddings'] = v if args.gpus: states = {"module.%s"%k : v for k, v in states.items()} try: model.load_state_dict(states) except Exception as err: import traceback if rank == 0: traceback.print_exc() model.load_state_dict(states, strict=False) if rank == 0: if not os.path.exists(os.path.dirname(save_fn)): try: os.makedirs(os.path.dirname(save_fn)) except OSError as exc: # Guard against race condition if exc.errno != errno.EEXIST: raise if args.gpus: torch.save(model.state_dict(), save_fn ) else: xm.save(model.state_dict(), save_fn ) model.train() if args.anomaly_detection and rank == 0: torch.set_anomaly_enabled(True) ########################## ## ## Init LR Scheduler ## ########################## if not batched_already: args.total_num_updates = args.total_num_updates // args.batch_size args.warmup_updates = args.total_num_updates // args.batch_size args.total_num_updates = args.total_num_updates // args.world_size args.warmup_updates = args.total_num_updates // args.world_size scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_updates, num_training_steps=args.total_num_updates, ) step_i = 0 err = None tb = None #tb = SummaryWriter() try: if rank == 0: pbar = tqdm(total=args.total_num_updates, file=sys.stdout) while step_i < args.total_num_updates: if not args.gpus: batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device) n_samples = len(batches) for sample in batches: step_i += 1 if step_i > args.total_num_updates: break report_step = step_i % args.log_interval == 0 while True: # the loop only for apex Gradient Overflow optimizer.zero_grad() total_loss, log = get_loss( model, sample, args=args, device=device, gpus=args.gpus, report=report_step ) if args.gpus: default_optimizer_step = optimizer.step with amp.scale_loss(total_loss, optimizer) as scaled_loss: scaled_loss.backward() # If Amp detects an overflow, it patches optimizer.step. In other words, if optimizer.step # was left unpatched, there was no overflow, and we don't need to replay. if optimizer.step is default_optimizer_step: optimizer.step() break optimizer.step() # If an overflow was detected, "optimizer.step" is the patched call, which does # nothing but restore optimizer.step to default_optimizer_step. 
if rank == 0: print("Overflowed, reducing loss scale and replaying batch.", flush=True) else: total_loss.backward() xm.optimizer_step(optimizer) xm.mark_step() break scheduler.step() if report_step: if 'loss' not in log: log['loss'] = total_loss # tb.add_scalar("Loss", total_loss, step_i) for k, v in log.items(): try: dist.all_reduce(v, op=dist.reduce_op.SUM) log[k] = float(v) except Exception as e: print(v, e) pass if args.gpus: if rank == 0: pbar.set_description(format_log(log, log_formatter, tb, step_i)) else: xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i)) if args.report_metrics: xm.master_print(met.metrics_report()) if rank == 0: pbar.update(1) if rank == 0: pbar.close() if eval_loaders: model.half() model.eval() model.cuda() for k, v in model.named_parameters(): v.requires_grad =False for split, eval_loader in zip(args.splits.split(','), eval_loaders): batches = eval_loader if rank == 0: eval_length = len(batches) if not batched_already: eval_length = eval_length // args.batch_size eval_length = eval_length // args.world_size pbar = tqdm(total=eval_length, file=sys.stdout) if not args.gpus: batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device) with torch.no_grad(): record = OrderedDict() for sample in batches: evaluate( model, sample, args=args, device=device, record=record, gpus=args.gpus, report=False ) if rank == 0: pbar.update(1) for k, v in record.items(): try: def handle_reduce(v): if len(v.shape) == 0: dist.all_reduce(v, op=dist.reduce_op.SUM) else: L = [torch.ones_like(v) for _ in range(dist.get_world_size())] dist.all_gather(L, v) v = torch.car(L, dim=0) return v if isinstance(v, list): v = [handle_reduce(e) for e in v] else: v = handle_reduce(v) record[k] = float(v) except Exception as e: pass post_evaluate(record, args=args) import json if rank == 0: print('',flush=True) print('Test result for %s'%split, flush=True) print(json.dumps(record, indent=2),flush=True) print('',flush=True) except Exception as _err: err = _err finally: folder = os.path.split(os.path.abspath(save_fn))[0] os.makedirs(folder, exist_ok=True) if rank == 0: print("Saving to %s"%save_fn) if args.gpus: torch.save(model.state_dict(), save_fn ) if err: raise err else: xm.save(model.state_dict(), save_fn ) if err: raise err print("Saved to %s"%save_fn)
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
    device = xm.xla_device()
    dataloader = xla_pl.ParallelLoader(dataloader, [device])
    dataloader = dataloader.per_device_loader(device)
    return dataloader
def get_tpu_dataloader(self, data_loader):
    return pl.ParallelLoader(data_loader, [self.device]).per_device_loader(self.device)
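# --- Hedged usage sketch (not part of the original code) ---------------------
# How the two small wrappers above are typically used: the ParallelLoader is
# rebuilt every epoch because per_device_loader() returns a single-pass
# iterator, and the wrapped loader also handles host-to-TPU transfers and the
# mark_step() barrier between batches. `train_step` is an assumed callable.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def run_epochs(train_loader, train_step, num_epochs):
    device = xm.xla_device()
    for epoch in range(num_epochs):
        loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        for batch in loader:
            train_step(batch)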
def _prediction_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None) -> PredictionOutput: """ Prediction/evaluation loop, shared by `evaluate()` and `predict()`. Works both with or without labels. """ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only model = self.model # multi-gpu eval if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) else: model = self.model # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. batch_size = dataloader.batch_size logger.info("***** Running %s *****", description) logger.info(" Num examples = %d", self.num_examples(dataloader)) logger.info(" Batch size = %d", batch_size) eval_losses: List[float] = [] preds: torch.Tensor = None label_ids: torch.Tensor = None model.eval() if is_torch_tpu_available(): dataloader = pl.ParallelLoader( dataloader, [self.args.device]).per_device_loader(self.args.device) for inputs in tqdm(dataloader, desc=description): has_labels = any( inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) for k, v in inputs.items(): inputs[k] = v.to(self.args.device) with torch.no_grad(): outputs = model(**inputs) if has_labels: step_eval_loss, logits = outputs[:2] eval_losses += [step_eval_loss.mean().item()] else: logits = outputs[0] if not prediction_loss_only: if preds is None: preds = logits.detach() else: preds = torch.cat((preds, logits.detach()), dim=0) if inputs.get("labels") is not None: if label_ids is None: label_ids = inputs["labels"].detach() else: label_ids = torch.cat( (label_ids, inputs["labels"].detach()), dim=0) if self.args.local_rank != -1: # In distributed mode, concatenate all results from all nodes: if preds is not None: preds = self.distributed_concat( preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: label_ids = self.distributed_concat( label_ids, num_total_examples=self.num_examples(dataloader)) elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: preds = xm.mesh_reduce("eval_preds", preds, torch.cat) if label_ids is not None: label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: preds = preds.cpu().numpy() if label_ids is not None: label_ids = label_ids.cpu().numpy() if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} if len(eval_losses) > 0: metrics["eval_loss"] = np.mean(eval_losses) # Prefix all keys with eval_ for key in list(metrics.keys()): if not key.startswith("eval_"): metrics[f"eval_{key}"] = metrics.pop(key) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
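# --- Hedged sketch (not part of the original code) ---------------------------
# The TPU branch above relies on xm.mesh_reduce to merge the per-core shards of
# the eval predictions; a minimal standalone version of that aggregation step:
import torch
import torch_xla.core.xla_model as xm

def gather_eval_predictions(preds: torch.Tensor) -> torch.Tensor:
    # collects one `preds` tensor per process and concatenates them along dim 0
    return xm.mesh_reduce("eval_preds", preds, torch.cat)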
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None): """ Main training entry point. Args: model_path (:obj:`str`, `optional`): Local path to the model if the model to train has been instantiated from a local path. If present, training will resume from the optimizer/scheduler states loaded here. trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): The trial run or the hyperparameter dictionary for hyperparameter search. """ # This might change the seed so needs to run first. self._hp_search_setup(trial) # Model re-init if self.model_init is not None: # Seed must be set before instantiating the model when using model_init. set_seed(self.args.seed) model = self.model_init() self.model = model.to(self.args.device) # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None # Data loader and number of training steps train_dataloader = self.get_train_dataloader() num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) if self.args.max_steps > 0: t_total = self.args.max_steps num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( self.args.max_steps % num_update_steps_per_epoch > 0 ) else: t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs self.args.max_steps = t_total self.create_optimizer_and_scheduler(num_training_steps=t_total) # Check if saved optimizer or scheduler states exist if ( model_path is not None and os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(os.path.join(model_path, "scheduler.pt")) ): # Load in optimizer and scheduler states self.optimizer.load_state_dict( torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) ) self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) model = self.model if self.args.fp16 and _use_apex: if not is_apex_available(): raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) # Distributed training (should be after apex fp16 initialization) if self.args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank], output_device=self.args.local_rank, find_unused_parameters=True, ) if self.tb_writer is not None: self.tb_writer.add_text("args", self.args.to_json_string()) self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) # Train! if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() else: total_train_batch_size = ( self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) ) logger.info("***** Running training *****") logger.info(" Num examples = %d", self.num_examples(train_dataloader)) logger.info(" Num Epochs = %d", num_train_epochs) logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", total_train_batch_size) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) self.global_step = 0 self.epoch = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_path is not None: # set global_step to global_step of last saved checkpoint from model path try: self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) epochs_trained = self.global_step // num_update_steps_per_epoch steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch) logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", self.global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: self.global_step = 0 logger.info(" Starting fine-tuning.") tr_loss_sum = 0.0 loss_sum = defaultdict(float) best = {self.best_metric: None} model.zero_grad() disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm) for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) if is_torch_tpu_available(): parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( self.args.device ) epoch_iterator = parallel_loader else: epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. 
if self.args.past_index >= 0: self._past = None epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 epoch_pbar.update(1) continue model.train() inputs = self._prepare_inputs(inputs) inputs["output_attentions"] = self.length_drop_args.length_config is not None layer_config = sample_layer_configuration( model.config.num_hidden_layers, layer_dropout_prob=self.length_drop_args.layer_dropout_prob, layer_dropout=0, ) inputs["layer_config"] = layer_config inputs["length_config"] = self.length_drop_args.length_config outputs = model(**inputs) # Save past state if it exists if self.args.past_index >= 0: self._past = outputs[self.args.past_index] task_loss = self.div_loss(outputs[0]) if self.length_drop_args.length_adaptive: loss_sum["full"] += task_loss.item() loss = task_loss if self.length_drop_args.length_adaptive: loss = loss / (self.length_drop_args.num_sandwich + 2) tr_loss_sum += loss.item() if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # inplace distillation if self.length_drop_args.length_adaptive: logits = outputs[1].detach() for i in range(self.length_drop_args.num_sandwich + 1): inputs["output_attentions"] = True layer_config = sample_layer_configuration( model.config.num_hidden_layers, layer_dropout_prob=self.length_drop_args.layer_dropout_prob, layer_dropout=(self.length_drop_args.layer_dropout_bound if i == 0 else None), layer_dropout_bound=self.length_drop_args.layer_dropout_bound, ) inputs["layer_config"] = layer_config length_config = sample_length_configuration( self.args.max_seq_length, model.config.num_hidden_layers, layer_config, length_drop_ratio=(self.length_drop_args.length_drop_ratio_bound if i == 0 else None), length_drop_ratio_bound=self.length_drop_args.length_drop_ratio_bound, ) inputs["length_config"] = length_config outputs_sub = model(**inputs) task_loss_sub = self.div_loss(outputs_sub[0]) if i == 0: loss_sum["smallest"] += task_loss_sub.item() loss_sum["sub"] += 0 else: loss_sum["sub"] += task_loss_sub.item() / self.length_drop_args.num_sandwich logits_sub = outputs_sub[1] loss_fct = KLDivLoss(reduction="batchmean") kl_loss = loss_fct(F.log_softmax(logits, -1), F.softmax(logits_sub, -1)) loss = self.div_loss(kl_loss) loss_sum["kl"] += loss.item() / (self.length_drop_args.num_sandwich + 1) loss = loss / (self.length_drop_args.num_sandwich + 2) tr_loss_sum += loss.item() if self.args.fp16 and _use_native_amp: self.scaler.scale(loss).backward() elif self.args.fp16 and _use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps (step + 1) == len(epoch_iterator) <= self.args.gradient_accumulation_steps ): if self.args.fp16 and _use_native_amp: self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) elif self.args.fp16 and _use_apex: torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) if is_torch_tpu_available(): 
xm.optimizer_step(self.optimizer) elif self.args.fp16 and _use_native_amp: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.lr_scheduler.step() model.zero_grad() self.global_step += 1 self.epoch = epoch + (step + 1) / len(epoch_iterator) if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( self.global_step == 1 and self.args.logging_first_step ): # backward compatibility for pytorch schedulers lr = ( self.lr_scheduler.get_last_lr()[0] if version.parse(torch.__version__) >= version.parse("1.4") else self.lr_scheduler.get_lr()[0] ) loss = tr_loss_sum / self.args.logging_steps tr_loss_sum = 0.0 logs = {"lr": lr, "loss": loss} log_str = f"[{self.global_step:5d}] lr {lr:g} | loss {loss:2.3f}" for key, value in loss_sum.items(): value /= self.args.logging_steps loss_sum[key] = 0.0 logs[f"{key}_loss"] = value log_str += f" | {key}_loss {value:2.3f}" self.log(logs, "train") logger.info(log_str) ''' if ( self.args.evaluation_strategy == EvaluationStrategy.STEPS and self.global_step % self.args.eval_steps == 0 ): results = self.evaluate() self._report_to_hp_search(trial, epoch, results) ''' if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. if hasattr(model, "module"): assert ( model.module is self.model ), f"Module {model.module} should be a reference to self.model" else: assert model is self.model, f"Model {model} should be a reference to self.model" if self.args.evaluate_during_training: results = self.evaluate() results = {k[5:]: v for k, v in results.items() if k.startswith("eval_")} self.log(results, "dev") msg = " | ".join([f"{k} {v:.3f}" for k, v in results.items()]) logger.info(f" [{self.global_step:5d}] {msg}") # Save model checkpoint if self.args.save_only_best: output_dirs = [] else: checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}" if self.hp_search_backend is not None and trial is not None: run_id = ( trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() ) checkpoint_folder += f"-run-{run_id}" output_dirs = [os.path.join(self.args.output_dir, checkpoint_folder)] if self.args.evaluate_during_training: if best[self.best_metric] is None or results[self.best_metric] > best[self.best_metric]: logger.info("Congratulations, best model so far!") output_dirs.append(os.path.join(self.args.output_dir, "checkpoint-best")) best = results for output_dir in output_dirs: self.save_model(output_dir) if self.is_world_master() and self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) if self.is_world_process_zero(): self._rotate_checkpoints(use_mtime=True) ''' if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) elif self.is_world_process_zero(): torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) ''' epoch_pbar.update(1) if 0 < self.args.max_steps <= self.global_step: break epoch_pbar.close() train_pbar.update(1) ''' if self.args.evaluation_strategy == EvaluationStrategy.EPOCH: results = self.evaluate() self._report_to_hp_search(trial, epoch, results) ''' if self.args.tpu_metrics_debug or self.args.debug: if is_torch_tpu_available(): # 
tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) else: logger.warning( "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) if 0 < self.args.max_steps <= self.global_step: break train_pbar.close() if self.tb_writer: self.tb_writer.close() if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") return self.global_step, best
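# --- Hedged sketch (not part of the original code) ---------------------------
# The training step above chooses between native AMP, apex AMP, and fp32 for
# the backward pass and gradient clipping. Condensed into one helper (argument
# names are assumptions, the library calls are standard; the TPU path via
# xm.optimizer_step is omitted here):
import torch

def backward_and_step(loss, model, optimizer, max_grad_norm, scaler=None, use_apex=False):
    if scaler is not None:                       # fp16 with torch.cuda.amp
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    elif use_apex:                               # fp16 with apex
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
        optimizer.step()
    else:                                        # full precision
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    optimizer.zero_grad()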
def train(self, data_loader): losses = AverageMeter() self.model.train() print_idx = int(len(data_loader) * self.tpu_print / 100) if self.accumulation_steps > 1: self.optimizer.zero_grad() if self.use_tpu: para_loader = pl.ParallelLoader(data_loader, [self.device]) tk0 = para_loader.per_device_loader(self.device) else: tk0 = tqdm(data_loader, total=len(data_loader)) for b_idx, data in enumerate(tk0): if self.accumulation_steps == 1 and b_idx == 0: self.optimizer.zero_grad() if self.model_fn is None: for key, value in data.items(): data[key] = value.to(self.device) _, loss = self.model(**data) else: loss = self.model_fn(data, self.device, self.model) if not self.use_tpu: with torch.set_grad_enabled(True): if self.use_mean_loss: loss = loss.mean() if self.fp16: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() if (b_idx + 1) % self.accumulation_steps == 0: self.optimizer.step() if self.scheduler is not None: self.scheduler.step() if b_idx > 0: self.optimizer.zero_grad() else: loss.backward() xm.optimizer_step(self.optimizer) if self.scheduler is not None: self.scheduler.step() if b_idx > 0: self.optimizer.zero_grad() if self.use_tpu: reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn) losses.update(reduced_loss.item(), data_loader.batch_size) else: losses.update(loss.item(), data_loader.batch_size) if not self.use_tpu: tk0.set_postfix(loss=losses.avg) else: if b_idx % print_idx == 0 or b_idx == len(data_loader): xm.master_print( f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}" ) if not self.use_tpu: tk0.close() return losses.avg
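# --- Hedged sketch (not part of the original code) ---------------------------
# `reduce_fn` is referenced above but not defined in the snippet; a common
# choice is to average the per-core losses, e.g.:
def reduce_fn(values):
    # `values` holds one loss value per TPU process
    return sum(values) / len(values)

# usage, as in the loop above:
#   reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)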
def train_model(index): device = xm.xla_device() torch.manual_seed(0) if not os.path.exists(tokenized_data_path): os.mkdir(tokenized_data_path) if not args.pretrained_model: model = transformers.modeling_gpt2.GPT2LMHeadModel( config=model_config) else: model = transformers.modeling_gpt2.GPT2LMHeadModel( config=model_config) model.load_state_dict(torch.load(output_dir + 'final_model')) model.train() model.to(device) multi_gpu = False full_len = 0 # print('calculating total steps') # for i in tqdm(range(num_pieces)): # with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f: # full_len += len([int(item) for item in f.read().strip().split()]) # total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation) # print('total steps = {}'.format(total_steps)) optimizer = transformers.AdamW(model.parameters(), lr=lr, correct_bias=True) # scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps) # if fp16: # try: # from apex import amp # except ImportError: # raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") # model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) # if torch.cuda.device_count() > 1: # print("Let's use", torch.cuda.device_count(), "GPUs!") # model = DataParallel(model) # multi_gpu = True if xm.is_master_ordinal(): print('starting training') doc_size = 10 raw_data_batch_len = len(raw_data_files) // doc_size for epoch in range(epochs): if xm.is_master_ordinal(): print('epoch {}'.format(epoch + 1)) now = datetime.now() print('time: {}'.format(now)) for batch_len in range(raw_data_batch_len): train_dataset = TextDataset( raw_data_files[batch_len * doc_size:(batch_len + 1) * doc_size], tokenized_data_path, full_tokenizer, n_ctx) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) # Creates dataloaders, which load data in batches # Note: test loader is not shuffled or sampled train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=8, drop_last=True) # tokens = get_tokenization(raw_data_file, tokenized_data_path, full_tokenizer) # if tokens is None: # continue # start_point = 0 # samples = [] # while start_point < len(tokens) - n_ctx: # samples.append(tokens[start_point: start_point + n_ctx]) # start_point += stride # if start_point < len(tokens): # samples.append(tokens[len(tokens) - n_ctx:]) # random.shuffle(samples) para_train_loader = pl.ParallelLoader( train_loader, [device]).per_device_loader(device) running_loss = 0 for step, batch_inputs in enumerate(para_train_loader): # for step in range(len(samples) // batch_size): # prepare data # batch = samples[step * batch_size: (step + 1) * batch_size] # batch_labels = [] # batch_inputs = [] # for ids in batch: # int_ids_for_labels = [int(x) for x in ids] # int_ids_for_inputs = [int(x) for x in ids] # batch_labels.append(int_ids_for_labels) # batch_inputs.append(int_ids_for_inputs) # print(batch_inputs) batch_inputs = batch_inputs.to(device) # print(batch_labels.size(), batch_inputs.size()) # forward pass outputs = model.forward(input_ids=batch_inputs, labels=batch_inputs) loss, logits = outputs[:2] # get loss # if multi_gpu: # loss = loss.mean() # if gradient_accumulation > 1: # loss = loss / gradient_accumulation optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 
xm.optimizer_step(optimizer) # if (step + 1) % gradient_accumulation == 0: # running_loss += loss.item() # optimizer.step() # xm.optimizer_step(optimizer) # optimizer.zero_grad() # scheduler.step() if xm.is_master_ordinal(): if (step + 1) % log_step == 0: print( 'now time: {}:{}. Step {}/{} of piece {}/{} epoch {}, loss {}' .format(datetime.now().hour, datetime.now().minute, (step + 1), len(para_train_loader), batch_len + 1, raw_data_batch_len, epoch + 1, running_loss / log_step)) running_loss = 0 else: running_loss += loss.item() xm.save(model.state_dict(), output_dir + 'final_model') if xm.is_master_ordinal(): gc.collect()
def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False): """Run evaluation code. Args: model: PT model dataloaders: list of PT dataloaders max_batches: Scalar test_mode: """ # enable eval mode model.zero_grad() model.eval() # copy properties for forward overrides self.copy_trainer_model_properties(model) # disable gradients to save memory torch.set_grad_enabled(False) # bookkeeping outputs = [] # run validation for dataloader_idx, dataloader in enumerate(dataloaders): dl_outputs = [] # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device(self.tpu_id) dataloader = xla_pl.ParallelLoader(dataloader, [device]) dataloader = dataloader.per_device_loader(device) for batch_idx, batch in enumerate(dataloader): if batch is None: continue # stop short when on fast_dev_run (sets max_batch=1) if batch_idx >= max_batches: break # callbacks if test_mode: self.on_test_batch_start() else: self.on_validation_batch_start() # ----------------- # RUN EVALUATION STEP # ----------------- if self.use_amp and self.use_native_amp: with torch.cuda.amp.autocast(): output = self.evaluation_forward( model, batch, batch_idx, dataloader_idx, test_mode) else: output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) # on dp / ddp2 might still want to do something with the batch parts if test_mode: if self.is_overridden('test_step_end'): model_ref = self.get_model() with self.profiler.profile('test_step_end'): output = model_ref.test_step_end(output) self.on_test_batch_end() else: if self.is_overridden('validation_step_end'): model_ref = self.get_model() with self.profiler.profile('validation_step_end'): output = model_ref.validation_step_end(output) self.on_validation_batch_end() # track outputs for collation dl_outputs.append(output) outputs.append(dl_outputs) eval_results = {} # with a single dataloader don't pass an array if len(dataloaders) == 1: outputs = outputs[0] # give model a chance to do something with the outputs (and method defined) if isinstance( model, (LightningDistributedDataParallel, LightningDataParallel)): model = model.module if test_mode: if self.is_overridden('test_end', model=model): # TODO: remove in v1.0.0 eval_results = model.test_end(outputs) rank_zero_warn( 'Method `test_end` was deprecated in v0.7 and will be removed v1.0.' ' Use `test_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('test_epoch_end', model=model): eval_results = model.test_epoch_end(outputs) else: if self.is_overridden('validation_end', model=model): # TODO: remove in v1.0.0 eval_results = model.validation_end(outputs) rank_zero_warn( 'Method `validation_end` was deprecated in v0.7 and will be removed v1.0.' ' Use `validation_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('validation_epoch_end', model=model): eval_results = model.validation_epoch_end(outputs) # enable train mode again model.train() # enable gradients to save memory torch.set_grad_enabled(True) return eval_results
def train(args, model, tokenizer): """ Train the model """ if xm.is_master_ordinal(): tb_writer = SummaryWriterP(args.output_dir) def summary_write(*args, **kwargs): if xm.is_master_ordinal(): tb_writer.add_scalar(*args, **kwargs) args.train_batch_size = args.per_gpu_train_batch_size #* max(1, args.n_gpu) train_dataloader = build_dataloader(args, tokenizer) if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = args.max_steps // ( len(train_dataloader) // args.gradient_accumulation_steps) + 1 else: t_total = len( train_dataloader ) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [{ 'params': [ p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay) ], 'weight_decay': args.weight_decay }, { 'params': [ p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay) ], 'weight_decay': 0.0 }] # Scale learning rate to num cores #args.learning_rate = args.learning_rate * xm.xrt_world_size() if args.sgd: optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate) else: optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) warmup_steps = args.warmup_samples // (args.train_batch_size * xm.xrt_world_size()) if args.lr_decay: scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total) elif args.lr_cosine: scheduler = WarmupCosineWithHardRestartsSchedule( optimizer, warmup_steps=warmup_steps, t_total=t_total, cycles=args.num_train_epochs) else: scheduler = WarmupZeroSchedule(optimizer, warmup_steps=warmup_steps) # Train! tracker = xm.RateTracker() log_info("***** Running training *****") log_info(" Num Epochs = %d", args.num_train_epochs) log_info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) log_info( " Total train batch size (w. 
parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps * (xm.xrt_world_size() if args.local_rank != -1 else 1)) log_info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) log_info(" Total optimization steps = %d", t_total) try: with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c: global_step = int(c.readline()) except OSError as e: global_step = 0 moving_loss = MovingLoss(10000 // args.logging_steps) train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=not xm.is_master_ordinal()) try: for epoch in train_iterator: p_train_dataloader = pl.ParallelLoader(train_dataloader, [args.device]) epoch_iterator = tqdm(p_train_dataloader.per_device_loader( args.device), total=len(train_dataloader), desc="Iteration", disable=not xm.is_master_ordinal()) model.train() for step, batch in enumerate(epoch_iterator): optimizer.zero_grad() inputs, labels = mask_tokens( batch, tokenizer, args) if args.mlm else (batch, batch) outputs = model( inputs, masked_lm_labels=labels) if args.mlm else model( inputs, labels=labels) loss = outputs[ 0] # model outputs are always tuple in pytorch-transformers (see doc) if args.n_gpu > 1: loss = loss.mean( ) # mean() to average on multi-gpu parallel training if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) xm.optimizer_step(optimizer, barrier=True) scheduler.step() global_step += 1 tracker.add(args.train_batch_size) if args.logging_steps > 0 and global_step % args.logging_steps == 0: ls = loss.item( ) # weird. if you call loss.item() only in one process, the whole thing hangs. So call on every and log in one. moving_loss.add(ls) summary_write('lr', scheduler.get_last_lr()[0], global_step) epoch_iterator.set_postfix( MovingLoss=f'{moving_loss.loss:.2f}', Perplexity= f'{torch.exp(torch.tensor(moving_loss.loss)):.2f}') if args.save_steps > 0 and global_step % args.save_steps == 0: save_state(args, model, tokenizer, global_step) if step >= 2: # TPU seems to like consistent epoch lenght if xm.is_master_ordinal(): print(met.metrics_report()) exit(0) # epoch_iterator.close() # break if args.max_steps > 0 and step > args.max_steps: epoch_iterator.close() break # evaluate once in an epoch if args.evaluate_during_training: results = evaluate(args, model, tokenizer, f"checkpoint-{global_step}") log_info(f"Eval {results}") for key, value in results.items(): summary_write("eval_{}".format(key), value, global_step) # reload dataset every args.reload_data_file epochs if args.reload_data_file and (epoch + 1) % args.reload_data_file == 0: train_dataloader = build_dataloader(args, tokenizer) # that's very slow on TPU #print_sample(model, tokenizer, args.device, args) except (KeyboardInterrupt, SystemExit): save_state(args, model, tokenizer, global_step) raise save_state(args, model, tokenizer, global_step) return global_step, moving_loss.loss
def train_fn(df): size = 1; torch.manual_seed(42) df = shuffle(df) split = np.int32(SPLIT*len(df)) val_df, train_df = df[split:], df[:split] val_df = val_df.reset_index(drop=True) val_set = QuoraDataset(val_df, tokenizer) val_sampler = DistributedSampler(val_set, num_replicas=8, rank=xm.get_ordinal(), shuffle=True) train_df = train_df.reset_index(drop=True) train_set = QuoraDataset(train_df, tokenizer) train_sampler = DistributedSampler(train_set, num_replicas=8, rank=xm.get_ordinal(), shuffle=True) val_loader = DataLoader(val_set, VAL_BATCH_SIZE, sampler=val_sampler, num_workers=0, drop_last=True) train_loader = DataLoader(train_set, BATCH_SIZE, sampler=train_sampler, num_workers=0, drop_last=True) device = xm.xla_device() network = Roberta().to(device) optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0]*size}, {'params': network.dense_output.parameters(), 'lr': LR[1]*size}]) val_losses, val_f1s = [], [] train_losses, train_f1s = [], [] start = time.time() xm.master_print("STARTING TRAINING ...\n") for epoch in range(EPOCHS): batch = 1 network.train() fonts = (fg(48), attr('reset')) xm.master_print(("EPOCH %s" + str(epoch+1) + "%s") % fonts) val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device) train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device) for train_batch in train_parallel: train_targ, train_in, train_att = train_batch network = network.to(device) train_in = train_in.to(device) train_att = train_att.to(device) train_targ = train_targ.to(device) train_preds = network.forward(train_in, train_att) train_loss = bce(train_preds, train_targ)/len(train_preds) train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1)) optimizer.zero_grad() train_loss.backward() xm.optimizer_step(optimizer) end = time.time() batch = batch + 1 is_print = batch % 10 == 1 f1 = np.round(train_f1.item(), 3) if is_print: print_metric(f1, batch, None, start, end, metric="F1", typ="Train") val_loss, val_f1, val_points = 0, 0, 0 network.eval() with torch.no_grad(): for val_batch in val_parallel: targ, val_in, val_att = val_batch targ = targ.to(device) val_in = val_in.to(device) val_att = val_att.to(device) network = network.to(device) pred = network.forward(val_in, val_att) val_points += len(targ) val_loss += bce(pred, targ).item() val_f1 += f1_score(pred, targ.squeeze(dim=1)).item()*len(pred) end = time.time() val_f1 /= val_points val_loss /= val_points f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x)/len(x)) loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x)/len(x)) print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val") xm.master_print("") val_f1s.append(f1); train_f1s.append(train_f1.item()) val_losses.append(loss); train_losses.append(train_loss.item()) xm.master_print("ENDING TRAINING ...") xm.save(network.state_dict(), MODEL_SAVE_PATH); del network; gc.collect() metric_lists = [val_losses, train_losses, val_f1s, train_f1s] metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_'] for i, metric_list in enumerate(metric_lists): for j, metric_value in enumerate(metric_list): torch.save(metric_value, metric_names[i] + str(j) + '.pt')
def evaluate(self, model, dataloaders, max_batches, test=False): """Run evaluation code. :param model: PT model :param dataloaders: list of PT dataloaders :param max_batches: Scalar :param test: boolean :return: """ # enable eval mode model.zero_grad() model.eval() # copy properties for forward overrides self.copy_trainer_model_properties(model) # disable gradients to save memory torch.set_grad_enabled(False) # bookkeeping outputs = [] # run validation for dataloader_idx, dataloader in enumerate(dataloaders): dl_outputs = [] # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device() dataloader = xla_pl.ParallelLoader(dataloader, [device]) dataloader = dataloader.per_device_loader(device) for batch_idx, batch in enumerate(dataloader): if batch is None: # pragma: no cover continue # stop short when on fast_dev_run (sets max_batch=1) if batch_idx >= max_batches: break # ----------------- # RUN EVALUATION STEP # ----------------- output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test) # track outputs for collation dl_outputs.append(output) # batch done if batch_idx % self.progress_bar_refresh_rate == 0: if test: self.test_progress_bar.update( self.progress_bar_refresh_rate) else: self.val_progress_bar.update( self.progress_bar_refresh_rate) self.main_progress_bar.update( self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} # with a single dataloader don't pass an array if len(dataloaders) == 1: outputs = outputs[0] # give model a chance to do something with the outputs (and method defined) model = self.get_model() if test and self.is_overriden('test_end'): eval_results = model.test_end(outputs) elif self.is_overriden('validation_end'): eval_results = model.validation_end(outputs) # enable train mode again model.train() # enable gradients to save memory torch.set_grad_enabled(True) return eval_results
def run_training_epoch(self): # Epoch begin callbacks self.on_epoch_start() # before epoch hook if self.is_function_implemented('on_epoch_start'): model = self.get_model() with self.profiler.profile('on_epoch_start'): model.on_epoch_start() # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(self.get_model()) # track local dataloader so TPU can wrap each epoch train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device() train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device) # run epoch for batch_idx, batch in self.profiler.profile_iterable( enumerate(train_dataloader), "get_train_batch"): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break self.batch_idx = batch_idx model = self.get_model() model.global_step = self.global_step # --------------- # RUN TRAIN STEP # --------------- output = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics = output # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 should_check_val = not self.disable_validation and can_check_epoch should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch) # fast_dev_run always forces val checking after train batch if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) if self.enable_early_stop: self.early_stop_callback.check_metrics( self.callback_metrics) # when logs should be saved should_save_log = ( batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: if self.proc_rank == 0 and self.logger is not None: self.logger.save() # when metrics should be logged should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 self.total_batch_idx += 1 # max steps reached, end training if self.max_steps is not None and self.max_steps == self.global_step: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if early_stop_epoch or self.fast_dev_run: break # epoch end hook if self.is_function_implemented('on_epoch_end'): model = self.get_model() with self.profiler.profile('on_epoch_end'): model.on_epoch_end() # Epoch begin callbacks self.on_epoch_end()
def run(): torch.manual_seed(seed) device = xm.xla_device() model = MX.to(device) # DataLoaders train_dataset = TweetDataset(args=args, df=train_df, mode="train", fold=args.fold_index, tokenizer=tokenizer) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, drop_last=False, num_workers=2) valid_dataset = TweetDataset(args=args, df=train_df, mode="valid", fold=args.fold_index, tokenizer=tokenizer) valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False) valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=1, drop_last=False) param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_parameters = [ { 'params': [ p for n, p in param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': 0.001 }, { 'params': [ p for n, p in param_optimizer if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0 }, ] num_train_steps = int(num_train_dpoints / args.batch_size / xm.xrt_world_size() * args.epochs) optimizer = AdamW(optimizer_parameters, lr=args.learning_rate * xm.xrt_world_size()) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=num_train_steps) xm.master_print("Training is Starting ...... ") best_jac = 0 #early_stopping = utils.EarlyStopping(patience=2, mode="max", verbose=True) for epoch in range(args.epochs): para_loader = pl.ParallelLoader(train_loader, [device]) train_loss = train(args, para_loader.per_device_loader(device), model, device, optimizer, scheduler, epoch, f) para_loader = pl.ParallelLoader(valid_loader, [device]) valid_jac = valid(args, para_loader.per_device_loader(device), model, device, tokenizer, epoch, f) jac = xm.mesh_reduce("jac_reduce", valid_jac, reduce_fn) xm.master_print(f"**** Epoch {epoch+1} **==>** Jaccard = {jac}") log_ = f"**** Epoch {epoch+1} **==>** Jaccard = {jac}" f.write(log_ + "\n\n") if jac > best_jac: xm.master_print("**** Model Improved !!!! Saving Model") xm.save( model.state_dict(), os.path.join(args.save_path, f"fold_{args.fold_index}")) best_jac = jac
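# --- Illustrative numbers only (not from the original code) ------------------
# How num_train_steps and the core-scaled learning rate work out for a
# hypothetical run on 8 TPU cores:
num_train_dpoints = 25_600                       # assumed dataset size
batch_size, world_size, epochs = 32, 8, 3        # world_size = xm.xrt_world_size()
num_train_steps = int(num_train_dpoints / batch_size / world_size * epochs)   # -> 300
learning_rate = 3e-5 * world_size                                             # -> 2.4e-4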
def train_tpu(): torch.manual_seed(1) def get_dataset(): fold_number = 0 train_ = pd.read_csv(args.train_fold) train = ShopeeDataset( train_[train_['fold'] != fold_number].reset_index(drop=True)) test = ShopeeDataset( train_[train_['fold'] != fold_number].reset_index(drop=True), transform=args.test_args) return train, test # Using the serial executor avoids multiple processes # to download the same data. train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers, drop_last=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=True) # Scale learning rate to num cores learning_rate = 1e-5 * xm.xrt_world_size() # Get loss function, optimizer, and model device = xm.xla_device() model = WRAPPED_MODEL.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) loss_fn = nn.CrossEntropyLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, label) in enumerate(loader): optimizer.zero_grad() output = model(image=data, label=label, get_embedding=args.get_embeddings) loss = loss_fn(output, label) loss.backward() xm.optimizer_step(optimizer) tracker.add(args.batch_size) if x % 20 == 0: print( '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}' .format(xm.get_ordinal(), x, loss.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) def test_loop_fn(loader): model.eval() for x, (data, label) in enumerate(loader): output = model(image=data, label=label, get_embedding=args.get_embeddings) loss = loss_fn(output, label) if x % 20 == 0: print('[xla:{}]({}) Loss={:.5f}'.format( xm.get_ordinal(), x, loss.item()), flush=True) for epoch in range(1, args.n_epochs + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) para_loader = pl.ParallelLoader(test_loader, [device]) test_loop_fn(para_loader.per_device_loader(device))
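# --- Hedged sketch (not part of the original code) ---------------------------
# SERIAL_EXEC and WRAPPED_MODEL are referenced above but defined elsewhere; in
# the torch_xla multiprocessing examples they are module-level globals along
# these lines (build_model() is a stand-in for however the network used above
# is constructed):
import torch_xla.distributed.xla_multiprocessing as xmp

SERIAL_EXEC = xmp.MpSerialExecutor()               # serialises the dataset download across processes
WRAPPED_MODEL = xmp.MpModelWrapper(build_model())  # keeps one host copy of the weights to save RAM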
def train_imagenet(): print('==> Preparing data..') img_dim = get_model_property('img_dim') if FLAGS.fake_data: train_dataset_len = 1200000 # Roughly the size of Imagenet dataset. train_loader = xu.SampleGenerator( data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim), torch.zeros(FLAGS.batch_size, dtype=torch.int64)), sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size()) test_loader = xu.SampleGenerator( data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim), torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)), sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size()) else: normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_dataset = torchvision.datasets.ImageFolder( os.path.join(FLAGS.datadir, 'train'), transforms.Compose([ transforms.RandomResizedCrop(img_dim), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])) train_dataset_len = len(train_dataset.imgs) resize_dim = max(img_dim, 256) test_dataset = torchvision.datasets.ImageFolder( os.path.join(FLAGS.datadir, 'val'), # Matches Torchvision's eval transforms except Torchvision uses size # 256 resize for all models both here and in the train loader. Their # version crashes during training on 299x299 images, e.g. inception. transforms.Compose([ transforms.Resize(resize_dim), transforms.CenterCrop(img_dim), transforms.ToTensor(), normalize, ])) train_sampler = None if xm.xrt_world_size() > 1: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=FLAGS.batch_size, sampler=train_sampler, drop_last=FLAGS.drop_last, shuffle=False if train_sampler else True, num_workers=FLAGS.num_workers) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=FLAGS.test_set_batch_size, drop_last=FLAGS.drop_last, shuffle=False, num_workers=FLAGS.num_workers) torch.manual_seed(42) device = xm.xla_device() model = get_model_property('model_fn')().to(device) writer = None if xm.is_master_ordinal(): writer = test_utils.get_summary_writer(FLAGS.logdir) optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr, momentum=FLAGS.momentum, weight_decay=1e-4) num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * xm.xrt_world_size()) lr_scheduler = schedulers.wrap_optimizer_with_scheduler( optimizer, scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None), scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None), scheduler_divide_every_n_epochs=getattr( FLAGS, 'lr_scheduler_divide_every_n_epochs', None), num_steps_per_epoch=num_training_steps_per_epoch, summary_writer=writer) loss_fn = nn.CrossEntropyLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, target) in enumerate(loader): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() xm.optimizer_step(optimizer) tracker.add(FLAGS.batch_size) if lr_scheduler: lr_scheduler.step() if x % FLAGS.log_steps == 0: test_utils.print_training_update(device, x, loss.item(), tracker.rate(), tracker.global_rate()) def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum().item() total_samples += data.size()[0] accuracy = 100.0 * correct / total_samples test_utils.print_test_update(device, accuracy) return accuracy 
accuracy = 0.0 for epoch in range(1, FLAGS.num_epochs + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) para_loader = pl.ParallelLoader(test_loader, [device]) accuracy = test_loop_fn(para_loader.per_device_loader(device)) test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch) if FLAGS.metrics_debug: print(met.metrics_report()) test_utils.close_summary_writer(writer) return accuracy
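# --- Hedged sketch (not part of the original code) ---------------------------
# Typical torch_xla entry point for a FLAGS-driven per-core function like
# train_imagenet() above; the process count of 8 is an assumption.
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    global FLAGS
    FLAGS = flags
    train_imagenet()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8)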
def prediction_loop( self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None) -> PredictionOutput: """ Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. Works both with or without labels. """ if hasattr(self, "_prediction_loop"): warnings.warn( "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.", FutureWarning, ) return self._prediction_loop( dataloader, description, prediction_loss_only=prediction_loss_only) prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only model = self.model # multi-gpu eval if self.args.n_gpu > 1: model = torch.nn.DataParallel(model) else: model = self.model # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. batch_size = dataloader.batch_size logger.info("***** Running %s *****", description) logger.info(" Num examples = %d", self.num_examples(dataloader)) logger.info(" Batch size = %d", batch_size) eval_losses: List[float] = [] preds: torch.Tensor = None label_ids: torch.Tensor = None model.eval() if is_torch_tpu_available(): dataloader = pl.ParallelLoader( dataloader, [self.args.device]).per_device_loader(self.args.device) if self.args.past_index >= 0: self._past = None for inputs in tqdm(dataloader, desc=description): loss, logits, labels = self.prediction_step( model, inputs, prediction_loss_only) if loss is not None: eval_losses.append(loss) if logits is not None: preds = logits if preds is None else torch.cat( (preds, logits), dim=0) if labels is not None: label_ids = labels if label_ids is None else torch.cat( (label_ids, labels), dim=0) if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") if self.args.local_rank != -1: # In distributed mode, concatenate all results from all nodes: if preds is not None: preds = self.distributed_concat( preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: label_ids = self.distributed_concat( label_ids, num_total_examples=self.num_examples(dataloader)) elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: preds = xm.mesh_reduce("eval_preds", preds, torch.cat) if label_ids is not None: label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) # Finally, turn the aggregated tensors into numpy arrays. if preds is not None: preds = preds.cpu().numpy() if label_ids is not None: label_ids = label_ids.cpu().numpy() if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics( EvalPrediction(predictions=preds, label_ids=label_ids)) else: metrics = {} if len(eval_losses) > 0: metrics["eval_loss"] = np.mean(eval_losses) # Prefix all keys with eval_ for key in list(metrics.keys()): if not key.startswith("eval_"): metrics[f"eval_{key}"] = metrics.pop(key) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def run(): ''' Entire training loop - Create DataLoaders - Define Training Configuration - Launch Training Loop ''' # Num of available TPU cores if config.TPUs: n_TPUs = xm.xrt_world_size() DEVICE = xm.xla_device() else: DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' device = torch.device(DEVICE) # Read Data # df1 = pd.read_csv('data/jigsaw-toxic-comment-train.csv', usecols=['comment_text', 'toxic']) # df2 = pd.read_csv('data/jigsaw-unintended-bias-train.csv', usecols=['comment_text', 'toxic'], engine='python') # don't know why it was breaking with default C parser # df_train = df1 # pd.concat([df1,df2], axis=0).reset_index(drop=True) # df_valid = pd.read_csv('data/validation.csv') # Subsample df_train = pd.read_csv('data/jigsaw-toxic-comment-train-small.csv', usecols=['comment_text', 'toxic']) df_valid = pd.read_csv('data/validation-small.csv', usecols=['comment_text', 'toxic']) # Preprocess train_dataset = dataset.BERTDataset( comment=df_train.comment_text.values, target=df_train.toxic.values ) valid_dataset = dataset.BERTDataset( comment=df_valid.comment_text.values, target=df_valid.toxic.values ) drop_last=False train_sampler, valid_sampler = None, None if config.TPUs: drop_last=True train_sampler = DistributedSampler( train_dataset, num_replicas=n_TPUs, rank=xm.get_ordinal(), shuffle=True ) valid_sampler = DistributedSampler( valid_dataset, num_replicas=n_TPUs, rank=xm.get_ordinal(), shuffle=True ) # Create Data Loaders train_data_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=config.TRAIN_BATCH_SIZE, num_workers=4, drop_last=drop_last, sampler=train_sampler ) valid_data_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=1, drop_last=drop_last, sampler=valid_sampler ) # Machine Configuration if config.MODEL == 'bert': model = BERTBaseUncased() elif config.MODEL == 'distil-bert': model = DistilBERTBaseUncased() else: print('Model chosen in config not valid') exit() model.to(device) # Optimizer Configuration param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_parameters = [ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, ] lr = config.LR num_train_steps = int(len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS) # TODO: why do the LR increases because of a distributed training ? if config.TPUs: num_train_steps /= n_TPUs lr *= n_TPUs optimizer = AdamW(optimizer_parameters, lr=lr) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=num_train_steps ) if not config.TPUs: if N_GPU > 1: model = nn.DataParallel(model) # Training loop best_score = 0 for epoch in range(config.EPOCHS): if config.TPUs: train_loader = pl.ParallelLoader(train_data_loader, [device]) valid_loader = pl.ParallelLoader(valid_data_loader, [device]) train_fn(train_loader.per_device_loader(device), model, optimizer, device, scheduler) outputs, targets = eval_fn(valid_loader.per_device_loader(device), model, device) else: train_fn(train_data_loader, model, optimizer, device, scheduler) outputs, targets = eval_fn(valid_data_loader, model, device) targets = np.array(targets) >= 0.5 # TODO: why ? 
auc_score = metrics.roc_auc_score(targets, outputs) # Save if best print(f"AUC Score = {auc_score}") if auc_score > best_score: if not config.TPUs: torch.save(model.state_dict(), config.MODEL_PATH) else: xm.save(model.state_dict(), config.MODEL_PATH) best_score = auc_score
def train_mnist(): torch.manual_seed(1) """ When using a TPU, the dataset setup looks like this: train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) """ def get_dataset(): norm = transforms.Normalize((0.1307,), (0.3081,)) train_dataset = datasets.MNIST( FLAGS['datadir'], train=True, download=True, transform=transforms.Compose( [transforms.ToTensor(), norm])) test_dataset = datasets.MNIST( FLAGS['datadir'], train=False, download=True, transform=transforms.Compose( [transforms.ToTensor(), norm])) return train_dataset, test_dataset # Using the serial executor avoids multiple processes # downloading the same data. train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=FLAGS['batch_size'], sampler=train_sampler, num_workers=FLAGS['num_workers'], drop_last=True) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=FLAGS['batch_size'], shuffle=False, num_workers=FLAGS['num_workers'], drop_last=True) # Scale learning rate to world size lr = FLAGS['learning_rate'] * xm.xrt_world_size() # Get loss function, optimizer, and model """ When using a TPU, the device is set up like this: device = xm.xla_device() model = xmp.MpModelWrapper(MNIST()).to(device) """ device = xm.xla_device() model = WRAPPED_MODEL.to(device) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum']) loss_fn = nn.NLLLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, target) in enumerate(loader): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() # when using a TPU, step the optimizer with xm.optimizer_step(optimizer) xm.optimizer_step(optimizer) tracker.add(FLAGS['batch_size']) if x % FLAGS['log_steps'] == 0: print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format( xm.get_ordinal(), x, loss.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() data, pred, target = None, None, None for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum().item() total_samples += data.size()[0] accuracy = 100.0 * correct / total_samples print('[xla:{}] Accuracy={:.2f}%'.format( xm.get_ordinal(), accuracy), flush=True) return accuracy, data, pred, target # Train and eval loops accuracy = 0.0 data, pred, target = None, None, None for epoch in range(1, FLAGS['num_epochs'] + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) para_loader = pl.ParallelLoader(test_loader, [device]) accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) if FLAGS['metrics_debug']: xm.master_print(met.metrics_report(), flush=True) return accuracy, data, pred, target
def train(rank, args):
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()

    train_dataset = get_dataset(args)
    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates
    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=False)
    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:
        eval_dataset = get_eval_dataset(args)
        eval_sampler = None
        # the process group (GPU) / ordinal (TPU) was already set up above
        if args.gpus:
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=args.gpus,
                    rank=rank,
                    shuffle=False)
        else:
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=rank,
                    shuffle=False)
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size if not hasattr(eval_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args)
        model.cuda(rank)
        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)

        batches = train_loader
    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ## Model Creation
        ##
        ##########################
        model = get_model(args)

        ##########################
        ##
        ## For shared parameters, TPU requires modules to be tied after .to(device),
        ## so we record the shared parameters first
        ##
        ##########################
        shared_parameters = {
            e[0]: e[1:]
            for e in _catalog_shared_params(model)
        }
        model.to(device)
        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ## Init Optimizer
        ##
        ##########################
        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)

        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()
        # tracker = xm.RateTracker()

    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)

    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ## Init LR Scheduler
    ##
    ##########################
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0
    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break
                report_step = step_i % args.log_interval == 0

                while True:  # this loop only replays the batch on an apex gradient overflow
                    optimizer.zero_grad()
                    total_loss, log = get_loss(model, sample, args=args, device=device,
                                               gpu=args.gpus, report=report_step)
                    if args.gpus:
                        default_optimizer_step = optimizer.step
                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        # If Amp detects an overflow, it patches optimizer.step. In other words,
                        # if optimizer.step was left unpatched, there was no overflow and we
                        # don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break
                        # If an overflow was detected, "optimizer.step" is the patched call, which
                        # does nothing but restore optimizer.step to default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.",
                                  flush=True)
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()
                        break

                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter))
                        if args.report_metrics:
                            xm.master_print(met.metrics_report())

                if rank == 0:
                    pbar.update(1)

        if eval_loader is not None:
            model.eval()
            if not args.gpus:
                batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()
                for sample in batches:
                    evaluate(model, sample, args=args, device=device, record=record,
                             gpu=args.gpus, report=report_step)
                post_evaluate(record, args=args)
            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)
    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0 and args.gpus:
            torch.save(model.state_dict(), save_fn)
            if err:
                raise err
        else:
            xm.save(model.state_dict(), save_fn)
            if err:
                raise err
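train(rank, args) expects to be spawned once per device, with rank passed as the first argument. The launcher below is a hypothetical sketch under the conventions used above (args.gpus selects the NCCL/apex path, otherwise the TPU path); the environment variables, nprocs=8, and the launch() name itself are assumptions, not taken from the original code.

import os
import torch.multiprocessing as mp
import torch_xla.distributed.xla_multiprocessing as xmp

def launch(args):
    if args.gpus:
        # dist.init_process_group('nccl', ...) inside train() needs a rendezvous address
        args.world_size = args.gpus
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        mp.spawn(train, args=(args,), nprocs=args.gpus, join=True)
    else:
        # TPU: one process per core; both spawn helpers call train(index, args)
        xmp.spawn(train, args=(args,), nprocs=8, start_method='fork')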