def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure=None, using_native_amp=None):
    if optimizer_idx == 0:
        if self.trainer.use_tpu and XLA_AVAILABLE:
            xm.optimizer_step(optimizer)
        elif isinstance(optimizer, torch.optim.LBFGS):
            optimizer.step(second_order_closure)
        else:
            optimizer.step()
        # clear gradients
        optimizer.zero_grad()
    elif optimizer_idx == 1:
        pass
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = context.getattr_or(
        "optimizer",
        lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum),
    )
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(
                device, x, loss.item(), tracker.rate(), tracker.global_rate())
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.BCEWithLogitsLoss(
        pos_weight=torch.FloatTensor([3, 3, 3, 3, 3, 5]).to(device))
    log_loss = nn.BCEWithLogitsLoss(
        weight=torch.FloatTensor([1, 1, 1, 1, 1, 2]).to(device), reduction='none')

    def metric_fn(outputs, target):
        return (log_loss(outputs, target).sum(-1) / log_loss.weight.sum()).mean()

    if args.metric_loss:
        loss_fn = metric_fn

    optimizer = context.getattr_or(
        'optimizer',
        lambda: torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  betas=(0.9, 0.999),
                                  weight_decay=args.weight_decay))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type='WarmupAndExponentialDecayScheduler',
            scheduler_divisor=args.slr_divisor,
            scheduler_divide_every_n_epochs=args.slr_div_epochs,
            num_warmup_epochs=args.n_warmup,
            min_delta_to_update_lr=args.min_lr,
            num_steps_per_epoch=num_steps_per_epoch))
    score = MovingAverage(maxlen=500)
    metric = MovingAverage(maxlen=500)
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        if args.model_name == 'inception_v3':
            output = output.logits
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        score(loss.item())
        metric(metric_fn(output, target).item())
        if x % args.log_steps == 0:
            logging.info('[{}]({:5d}) Moving average loss: {:.5f}, metric: {:.5f}'
                         .format(device, x, score.mean(), metric.mean()))
        if lr_scheduler:
            lr_scheduler.step()
def train_one_epoch(self, train_loader, e, save_flag):
    self.model.train()
    losses = AverageMeter()
    final_scores = RocAucMeter()
    t = time.time()
    for step, (targets, inputs, attention_masks, ids) in enumerate(train_loader):
        if self.config.verbose:
            if step % self.config.verbose_step == 0:
                self.log(
                    f'Train Step {step}, loss: '
                    f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, '
                    f'time: {(time.time() - t):.5f}')

        inputs = inputs.to(self.device, dtype=torch.long)
        attention_masks = attention_masks.to(self.device, dtype=torch.long)
        targets = targets.to(self.device, dtype=torch.float)

        self.optimizer.zero_grad()
        outputs = self.model(inputs, attention_masks)
        loss = self.criterion(outputs, targets)

        batch_size = inputs.size(0)
        final_scores.update(targets, outputs)
        losses.update(loss.detach().item(), batch_size)

        loss.backward()
        xm.optimizer_step(self.optimizer)

        if self.config.step_scheduler:
            self.scheduler.step()

    self.model.eval()
    if save_flag == 1:
        self.save(f'{FILE_NAME}_epoch_{e}.bin')

    return losses, final_scores
def train_one_epoch(loader):
    model.train()
    running_loss = 0
    max_idx = 0
    xm.master_print("-" * 40)
    xm.master_print("Step\t|\tTime")
    xm.master_print("-" * 40)
    for idx, (images, targets) in enumerate(loader):
        optimizer.zero_grad()
        y_pred = model(images.float())
        loss = criterion(y_pred, targets)
        running_loss += float(loss)
        loss.backward()
        xm.optimizer_step(optimizer)  # call xm.mark_step() every step for grad accum
        max_idx = float(idx)
        if idx % FLAGS["log_steps"] == 0 and idx != 0:
            xm.master_print("({})\t|\t{}".format(idx, time.asctime(time.localtime())))
    xm.master_print("-" * 40)
    return running_loss / (max_idx + 1)
def step(self, curr):
    # selects the loss according to the schedule
    if self.schedule_coeff[self.i][0] < curr:
        self.i = self.i + 1
    self.which = self.l - self.i
    for optimizer in self.optimizers:
        optimizer.zero_grad()

    t = time()
    try:
        data = next(self.data)
    except StopIteration:
        self.Data_Generator.reset_generator()
        self.data = self.Data_Generator.next_batch()
        data = next(self.data)
    self.IO_time += time() - t

    t = time()
    loss = self.Network.score(imgL=data[0], imgR=data[1], which=self.which,
                              lp=self.i + 1, train=True)
    self.forward_time += time() - t

    t = time()
    self.sm += loss[0].detach()
    self.re += loss[2].detach()
    self.ds += loss[1].detach()
    self.em += loss[3].detach()
    l = 0
    for i in loss:
        l += i
    l = l.mul(self.schedule_coeff[self.i][1])
    l.backward()
    if self.tpu:
        for optimizer in self.optimizers:
            xm.optimizer_step(optimizer)  # , barrier=True)
    else:
        for optimizer in self.optimizers:
            optimizer.step()
    self.backward_time += time() - t
def tpu_train_fn(data_loader, model, optimizer, device,
                 num_batches, scheduler=None, loss_fn=None):
    model.train()
    tk0 = tqdm(data_loader, total=len(data_loader), desc="Training")
    for bi, d in enumerate(tk0):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]
        sentiment = d["sentiment"]
        orig_selected = d["orig_selected"]
        orig_tweet = d["orig_tweet"]
        targets_start = d["targets_start"]
        targets_end = d["targets_end"]
        offsets = d["offsets"]

        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets_start = targets_start.to(device, dtype=torch.long)
        targets_end = targets_end.to(device, dtype=torch.long)

        model.zero_grad()
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        )
        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        scheduler.step()
        tk0.set_postfix(loss=loss.item())
def train_loop_fn(train_loader, args, model, criterion, optimizer, device, scheduler=None):
    model.train()
    criterion.train()
    for i, sample in enumerate(train_loader):
        sample = _prepare_sample(sample, device)
        print(sample["target"].shape, sample["target"].device)
        optimizer.zero_grad()
        _, _, logging_output = criterion(model, sample)
        logging = criterion.aggregate_logging_outputs([logging_output])
        loss = logging["loss"]
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        if i % args.log_steps == 0:
            xm.master_print('bi={}, loss={:.4f}'.format(i, loss.item()))
            xm.master_print('MEM: {}'.format(psutil.virtual_memory()))
    print('End training: {}'.format(device))
def train(embedder, model, optimizer, trainloader, writer, logger, epoch, pt_dir, device):
    try:
        tracker = xm.RateTracker()
        criterion = nn.MSELoss()
        model.train()
        step = 0
        for batch_idx, (dvec_mel, target_mag, mixed_mag) in enumerate(trainloader):
            target_mag, mixed_mag = target_mag.to(device), mixed_mag.to(device)

            dvec_list = list()
            for mel in dvec_mel:
                mel = mel.to(device)
                dvec = embedder(mel)
                dvec_list.append(dvec)
            dvec = torch.stack(dvec_list, dim=0)
            dvec = dvec.detach()

            # mask model
            optimizer.zero_grad()
            mask = model(mixed_mag, dvec)
            output = mixed_mag * mask

            # calculate loss; the paper uses a power-law compression, but we don't do that here
            loss = criterion(output, target_mag)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(len(output))
            loss = loss.item()

            # log
            step += len(output)
            logger.info('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), batch_idx, loss, tracker.rate(),
                tracker.global_rate(), time.asctime()))
            if step % config.train['ckpt_interval'] == 0:
                model_saver(model, optimizer, pt_dir, epoch)
                logger.info("Saved checkpoint at epoch %d, step %d" % (epoch, step))
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
def train_one_epoch(self, train_loader):
    self.model.train()
    summary_loss = AverageMeter()
    final_scores = RocAucMeter()
    t = time.time()
    for step, (images, targets) in enumerate(train_loader):
        t0 = time.time()
        batch_size = images.shape[0]

        outputs = self.model(images)
        self.optimizer.zero_grad()
        loss = self.criterion(outputs, targets)
        loss.backward()  # compute and sum gradients on params
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=global_config.CLIP_GRAD_NORM)
        xm.optimizer_step(self.optimizer)

        if self.config.step_scheduler:
            self.scheduler.step()

        try:
            final_scores.update(targets, outputs)
        except Exception:
            # xm.master_print("outputs: ", list(outputs.data.cpu().numpy())[:10])
            pass
        summary_loss.update(loss.detach().item(), batch_size)

        if self.config.verbose:
            if step % self.config.verbose_step == 0:
                t1 = time.time()
                effNet_lr = np.format_float_scientific(
                    self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
                head_lr = np.format_float_scientific(
                    self.optimizer.param_groups[1]['lr'], unique=False, precision=1)
                xm.master_print(
                    f":::({str(step).rjust(4, ' ')}/{len(train_loader)}) | "
                    f"Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.5f} | "
                    f"LR: {effNet_lr}/{head_lr} | BTime: {t1 - t0:.2f}s | "
                    f"ETime: {int((t1 - t0) * (len(train_loader) - step) // 60)}m")
    return summary_loss, final_scores
def train_fn(epoch, train_dataloader, optimizer, criterion, scheduler, device):
    model.train()
    for batch_idx, batch_data in enumerate(train_dataloader):
        optimizer.zero_grad()
        batch_data = any2device(batch_data, device)
        outputs = model(**batch_data)

        y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
        y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

        loss = criterion(y_pred, y_true)
        if batch_idx % 100 == 0:  # log every 100 batches
            xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    for bi, d in enumerate(data_loader):  # bi --> batch index
        ids = d['ids']
        mask = d['mask']
        segment_ids = d['segment_ids']
        targets = d['targets']

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        segment_ids = segment_ids.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=segment_ids)
        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)  # used in place of optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if bi % 10 == 0:
            print(f"bi={bi}, loss={loss}")
def train():
    net.train()  # enter train mode
    loss_avg = 0.0
    for bx, by in tqdm(train_loader):
        # print(xmetrics.metrics_report())
        bx, by = bx.to(xla_device), by.to(xla_device)
        curr_batch_size = bx.size(0)

        # forward
        logits = net(bx * 2 - 1)

        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(logits, by)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        scheduler.step()

        # exponential moving average
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1

    state['train_loss'] = loss_avg
def loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for data, target in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
            optimizer.zero_grad()
            output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
            loss = xu.timed(
                lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
            xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
            xu.timed(
                lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
            self.assertLess(loss.cpu().item(), 3.0)
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    total_samples_train, correct_train = 0, 0

    # Training and calculating train accuracy and loss
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        train_loss = loss_fn(output, target)
        train_loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(data.shape[0])

        pred_train = output.max(1, keepdim=True)[1]
        correct_train += pred_train.eq(target.view_as(pred_train)).sum().item()
        total_samples_train += data.size()[0]
        scheduler.step()

        if x % 40 == 0:
            print(
                "[xla:{}]({})\tLoss={:.3f}\tRate={:.2f}\tGlobalRate={:.2f}".format(
                    xm.get_ordinal(), x, train_loss.item(),
                    tracker.rate(), tracker.global_rate()),
                flush=True,
            )

    train_accuracy = 100.0 * correct_train / total_samples_train
    print("[xla:{}] Accuracy={:.2f}%".format(xm.get_ordinal(), train_accuracy),
          flush=True)
    return train_accuracy
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.BCEWithLogitsLoss(
        reduction='mean', pos_weight=torch.FloatTensor([7]).to(device))
    optimizer = context.getattr_or(
        'optimizer',
        lambda: torch.optim.AdamW(model.parameters(), lr=args.lr, eps=1e-08,
                                  betas=(0.9, 0.999),
                                  weight_decay=args.weight_decay))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type='WarmupAndExponentialDecayScheduler',
            scheduler_divisor=args.slr_divisor,
            scheduler_divide_every_n_epochs=args.slr_divide_n_epochs,
            num_warmup_epochs=args.num_warmup_epochs,
            min_delta_to_update_lr=args.num_warmup_epochs,
            num_steps_per_epoch=num_steps_per_epoch))
    score = []
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output['out'], target)
        loss.backward()
        xm.optimizer_step(optimizer)
        score.append(loss.item())
        if (args.log_step) and (x % args.log_step) == 0:
            logging.info('[{}]({}) Loss={:.4f}'.format(device, x, loss.item()))
        if lr_scheduler:
            lr_scheduler.step()
    score = sum(score) / len(score)
    return score
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
        # batch = tuple(t.to(self.device) for t in batch)
        loss = model(*batch)  # the last element is the label
        # loss = criterion(output, batch[-1])
        loss.backward()
        # xm.optimizer_step(optimizer)
        # optimizer.zero_grad()
        tracker.add(FLAGS.batch_size)
        if (x + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            # Gradient accumulation: run several backward passes before the optimizer
            # updates the parameters (i.e. before optimizer.step()), so the accumulated
            # gradients are stored in parameter.grad; the parameters are then updated
            # once with the accumulated gradient.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
        if xm.get_ordinal() == 0:
            if x % FLAGS.log_steps == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)
def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    model.train()
    total = 0
    total_loss = 0.
    with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80, desc='train') as pbar:
        for n_iter, record in enumerate(
                data.iter_train_pairs(model, dataset, train_pairs, qrels, GRAD_ACC_SIZE)):
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            # import torch_xla.debug.metrics as met
            # print(met.metrics_report())
            if total % BATCH_SIZE == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(model.parameters(), lr=FLAGS.lr,
                          momentum=FLAGS.momentum, weight_decay=1e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
            scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(
                device, x, loss.item(), tracker.rate(),
                tracker.global_rate(), summary_writer=writer)
        if lr_scheduler:
            lr_scheduler.step()
def __optimizer_step(self, *args, closure: Optional[Callable] = None,
                     profiler_name: str = None, **kwargs):
    trainer = self._trainer
    optimizer = self._optimizer
    model = trainer.get_model()

    if trainer.on_tpu:
        with trainer.profiler.profile(profiler_name):
            xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})
    elif trainer.amp_backend is not None:
        trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
    else:
        with trainer.profiler.profile(profiler_name):
            optimizer.step(closure=closure, *args, **kwargs)

    accelerator_backend = trainer.accelerator_backend
    if accelerator_backend is not None and accelerator_backend.rpc_enabled:
        if accelerator_backend.ddp_plugin.is_main_rpc_process:
            # Initialize optimizer step on main process
            accelerator_backend.ddp_plugin.worker_optimizer_step(
                model=model, opt_idx=self._optimizer_idx, *args, **kwargs)

    trainer.train_loop.on_before_zero_grad(self)
    model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx,
                              optimizer, self._optimizer_idx)
def train(self):
    bar_total = tqdm(range(self.start_epoch, self.end_epoch), desc='Training', leave=False)
    n_samples = len(self.train_loader.sampler)
    for self.epoch in bar_total:
        total_loss = 0
        for data in self.train_loader:
            inputs, labels = data
            inputs, labels = Variable(inputs), Variable(labels)
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # inputs = inputs.transpose(1, 3)
            y_pred = self.model(inputs)
            loss = self.criterion(y_pred, labels)
            self.optimizer.zero_grad()
            loss.backward()
            if self.tpu:
                xm.optimizer_step(self.optimizer, barrier=True)
            else:
                self.optimizer.step()
            total_loss += loss.item()
        train_loss = total_loss / len(self.train_loader)
        bar_total.set_description("Loss: {}".format(train_loss))
        bar_total.refresh()
        if self.epoch % self.summary_write == 0:
            accuracy = self.evaluate()
            self.summary.add_scalar('Train loss', train_loss, self.epoch)
            self.summary.add_scalar('Validation accuracy', accuracy, self.epoch)
            self.summary.close()
        if self.epoch % self.save_model == 0:
            self.save_checkpoint()
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    positions = torch.arange(SEQUENCE_LENGTH).long().view(1, SEQUENCE_LENGTH).to(device)
    causal_mask = torch.triu(
        torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.uint8, device=device),
        diagonal=1).unsqueeze(0)
    model.train()
    for iteration, batch in enumerate(loader):
        optimizer.zero_grad()
        input = batch[:, :-1].long()
        target = batch[:, 1:].long()
        if not xla_enabled:
            input = input.to(device)
            target = target.to(device)
        if amp_enabled:
            loss = loop_with_amp(model, input, positions, target, causal_mask,
                                 optimizer, xla_enabled, autocast, scaler)
        else:
            loss = model(input, positions, target, batch_mask=causal_mask)
            loss.backward()
            if xla_enabled:
                xm.optimizer_step(optimizer)
            else:
                optimizer.step()
        tracker.add(BATCH_SIZE)
        if iteration % LOG_STEPS == 0:
            print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                device, iteration, loss.item() / math.log(2), tracker.rate()))
def optimizer_step(self, model: Union["pl.LightningModule", Module],
                   optimizer: Optimizer, optimizer_idx: int,
                   closure: Callable[[], Any], **kwargs: Any) -> None:
    if isinstance(model, pl.LightningModule):
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
    closure_result = xm.optimizer_step(
        optimizer, optimizer_args={"closure": closure, **kwargs})
    skipped_backward = closure_result is None
    # in manual optimization, the closure does not return a value
    if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
        # we lack coverage here so disable this - something to explore if there's demand
        raise MisconfigurationException(
            "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."
            " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
            " requesting this feature.")
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path: (Optional) Local path to model if model to train has been instantiated
            from a local path. If present, we will try reloading the optimizer/scheduler
            states from there.
    """
    train_dataloader = self.get_train_dataloader()
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) //
                      self.args.gradient_accumulation_steps * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(model_path, "optimizer.pt"),
                       map_location=self.args.device))
        scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

    model = self.model
    if self.args.fp16:
        if not is_apex_available():
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if self.args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())
        self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

    # Train!
    if is_torch_tpu_available():
        total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
    else:
        total_train_batch_size = (self.args.train_batch_size *
                                  self.args.gradient_accumulation_steps *
                                  (torch.distributed.get_world_size()
                                   if self.args.local_rank != -1 else 1))
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", self.num_examples(train_dataloader))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per device = %d",
                self.args.per_device_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                total_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = self.global_step // (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = self.global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", self.global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            self.global_step = 0
            logger.info("  Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch",
                            disable=not self.is_local_master())
    for epoch in train_iterator:
        if isinstance(train_dataloader, DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(
                train_dataloader, [self.args.device]).per_device_loader(self.args.device)
            epoch_iterator = tqdm(parallel_loader, desc="Iteration",
                                  disable=not self.is_local_master())
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration",
                                  disable=not self.is_local_master())

        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            tr_loss += self._training_step(model, inputs, optimizer)

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                if self.args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), self.args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                if is_torch_tpu_available():
                    xm.optimizer_step(optimizer)
                else:
                    optimizer.step()

                scheduler.step()
                model.zero_grad()
                self.global_step += 1
                self.epoch = epoch + (step + 1) / len(epoch_iterator)

                if (self.args.logging_steps > 0
                        and self.global_step % self.args.logging_steps == 0) or (
                            self.global_step == 1 and self.args.logging_first_step):
                    logs: Dict[str, float] = {}
                    logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (
                        scheduler.get_last_lr()[0]
                        if version.parse(torch.__version__) >= version.parse("1.4")
                        else scheduler.get_lr()[0])
                    logging_loss = tr_loss

                    self._log(logs)

                    if self.args.evaluate_during_training:
                        self.evaluate()

                if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                    # In all cases (even distributed/parallel), self.model is always a reference
                    # to the model we want to save.
                    if hasattr(model, "module"):
                        assert model.module is self.model
                    else:
                        assert model is self.model
                    # Save model checkpoint
                    output_dir = os.path.join(
                        self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                    self.save_model(output_dir)

                    if self.is_world_master():
                        self._rotate_checkpoints()

                    if is_torch_tpu_available():
                        xm.rendezvous("saving_optimizer_states")
                        xm.save(optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                        xm.save(scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                    elif self.is_world_master():
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))

            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
            train_iterator.close()
            break
        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

    if self.tb_writer:
        self.tb_writer.close()

    logger.info(
        "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    return TrainOutput(self.global_step, tr_loss / self.global_step)
def _train_one_epoch(self, loader):
    loader_time = .0
    train_time = .0
    curr_time = time.time()

    self.epoch_storage = defaultdict(list)
    for key in ['approx', 'target', 'loss', 'batch_metric']:
        self.epoch_storage[key] = []

    if self.fp16:
        scaler = amp.GradScaler()

    self.model.train()
    if self.progress_bar and self.rank == 0:
        iterator = enumerate(tqdm(loader, desc='train'))
    else:
        iterator = enumerate(loader)

    for batch_i, inputs in iterator:
        loader_time += time.time() - curr_time
        curr_time = time.time()
        self.optimizer.zero_grad()
        batches_done = len(loader) * (self.global_epoch - 1) + batch_i
        inputs = [t.to(self.device) for t in inputs]

        # forward and backward
        if self.fp16:
            with amp.autocast():
                loss, approx = self.forward_train(self, inputs)
                self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            scaler.scale(loss).backward()

            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.sam:
                    # first step
                    optimizer_state = scaler._per_optimizer_states[id(self.optimizer)]
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item()
                               for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.first_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                    # second step
                    with amp.autocast():
                        loss2, _ = self.forward_train(self, inputs)
                    scaler.scale(loss2).backward()
                    scaler.unscale_(self.optimizer)
                    if not sum(v.item()
                               for v in optimizer_state["found_inf_per_device"].values()):
                        self.optimizer.second_step(zero_grad=True)
                    optimizer_state["stage"] = 2
                    scaler.update()
                else:
                    scaler.step(self.optimizer)
                    scaler.update()
        else:
            loss, approx = self.forward_train(self, inputs)
            self.evaluate_batch(self, inputs, approx)  # evaluation
            loss = loss / self.grad_accumulations
            loss.backward()
            if (batch_i + 1) % self.grad_accumulations == 0:
                if self.xla:
                    if self.sam:
                        raise RuntimeError('SAM optimizer on XLA device is not available.')
                    else:
                        xm.optimizer_step(self.optimizer, barrier=True)
                else:
                    if self.sam:
                        self.optimizer.first_step(zero_grad=True)
                        loss2, _ = self.forward_train(self, inputs)
                        loss2.backward()
                        self.optimizer.second_step(zero_grad=True)
                    else:
                        self.optimizer.step()

        if self.batch_scheduler:
            self.scheduler.step()

        if self.parallel == 'ddp' and self.ddp_average_loss:
            if self.xla:
                loss_batch = xm.all_gather(loss.detach().clone().view(1)).mean().item()
            else:
                loss_batch = comm.gather_tensor(loss.detach().clone().view(1)).mean().item()
        else:
            # Use loss on device 0
            loss_batch = loss.item()

        # logging
        learning_rate = [param_group['lr'] for param_group in self.optimizer.param_groups]
        logs = [('batch_train_loss', loss_batch), ('batch_train_lr', learning_rate)]
        if len(self.epoch_storage['batch_metric']) > 0:
            metric = self.epoch_storage['batch_metric'][-1]
            logs.append(('batch_valid_metric', metric))
        self.tb_logger.list_of_scalars_summary(logs, batches_done)
        self.epoch_storage['loss'].append(loss_batch)

        train_time += time.time() - curr_time
        curr_time = time.time()

    if self.debug and self.rank == 0:
        self.logger(f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

    for key, val in self.epoch_storage.items():
        if len(val) > 0:
            if isinstance(val[0], torch.Tensor):
                self.epoch_storage[key] = torch.cat(val)
            else:
                self.epoch_storage[key] = torch.tensor(val).to(self.device)

    loss_total = self.epoch_storage['loss'].mean().item()

    if self.parallel == 'ddp':
        # gather tensors
        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if self.xla:
                    self.epoch_storage[key] = xm.all_gather(val)
                else:
                    self.epoch_storage[key] = comm.gather_tensor(val)
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)
    else:
        metric_total, monitor_metrics_total = self.evaluate_epoch(self)

    if metric_total is None:
        metric_total = loss_total

    # logging
    logs = [
        ('epoch_train_loss', loss_total),
        ('epoch_train_metric', metric_total),
    ]
    self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)

    return loss_total, metric_total, monitor_metrics_total
def train(epoch):
    logger.info('\nEpoch: %d' % epoch)
    net.train()
    train_loss = AverageMeter(100)
    acc = AverageMeter(100)
    batch_time = AverageMeter()
    reg_loss = AverageMeter(100)
    train_loss_avg = 0
    correct = 0
    total = 0
    mean = 0
    var = 0
    lambda_ = 0
    xi_ = 0
    for m in net.modules():
        if isinstance(m, Constraint_Norm):
            m.reset_norm_statistics()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        start = time.time()
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        else:
            inputs = inputs.to(device)
            targets = targets.to(device)
        bsz = inputs.size(0)
        outputs = net(inputs)
        if args.optim_loss == 'mse':
            targets = targets.float()
        loss = criterion(outputs, targets)

        # constraint loss
        weight_mean = 0
        weight_var = 0
        weight_mean_abs = 0
        weight_var_abs = 0
        for m in net.modules():
            if isinstance(m, Constraint_Lagrangian):
                weight_mean_, weight_var_ = m.get_weight_mean_var()
                weight_mean_abs_, weight_var_abs_ = m.get_weight_mean_var_abs()
                weight_mean += weight_mean_
                weight_var += weight_var_
                weight_mean_abs += weight_mean_abs_
                weight_var_abs += weight_var_abs_
        constraint_loss = args.lambda_weight_mean * weight_mean + weight_var
        constraint_loss = args.lambda_constraint_weight * constraint_loss
        weight_mean_abs = args.lambda_constraint_weight * weight_mean_abs
        weight_var_abs = args.lambda_constraint_weight * weight_var_abs

        # optimize constraint loss
        train_loss.update(loss.item())
        train_loss_avg += loss.item()
        loss += constraint_loss

        # optimize
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct_idx = predicted.eq(targets.data).cpu().sum().float()
        correct += correct_idx
        acc.update(100. * correct_idx / float(targets.size(0)))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        if use_cuda:
            optimizer.step()
        else:
            xm.optimizer_step(optimizer, barrier=True)

        batch_time.update(time.time() - start)
        remain_iter = args.epoch * len(trainloader) - (epoch * len(trainloader) + batch_idx)
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if (batch_idx + 1) % args.print_freq == 0:
            logger.info('Train: [{0}][{1}/{2}]\t'
                        'Loss {train_loss.avg:.3f}\t'
                        'acc {acc.avg:.3f}\t'
                        'correct: [{correct}/{total}]\t'
                        'Constraint mean {corat_mean:.4f}\t'
                        'Constraint var {corat_var:.4f}\t'
                        'Constraint lambda {corat_lambda:.4f}\t'
                        'Constraint xi {corat_xi:.4f}\t'
                        'mean {mean:.4f}\t'
                        'var {var:.4f}\t'
                        'remain_time: {remain_time}'.format(
                            epoch, batch_idx, len(trainloader),
                            train_loss=train_loss,
                            corat_mean=-1 * weight_mean.item(),
                            corat_var=-1 * weight_var.item(),
                            corat_lambda=lambda_,
                            corat_xi=xi_,
                            mean=mean,
                            var=var,
                            acc=acc,
                            correct=int(correct),
                            total=total,
                            remain_time=remain_time,
                        ))

        if (batch_idx + 1) % args.print_freq == 0:
            mean = []
            var = []
            for m in net.modules():
                if isinstance(m, Constraint_Norm):
                    mean_, var_ = m.get_mean_var()
                    mean.append(mean_.abs())
                    var.append(var_.abs())
            mean = torch.mean(torch.stack(mean))
            var = torch.mean(torch.stack(var))
            curr_idx = epoch * len(trainloader) + batch_idx
            tb_logger.add_scalar("train/train_loss", train_loss.avg, curr_idx)
            tb_logger.add_scalar("train/train_acc", acc.avg, curr_idx)
            tb_logger.add_scalar("train/norm_mean(abs)", mean, curr_idx)
            tb_logger.add_scalar("train/norm_var-1(abs)", var, curr_idx)
            tb_logger.add_scalar("train/weight_mean(abs)", weight_mean_abs.item(), curr_idx)
            tb_logger.add_scalar("train/weight_var-1(abs)", weight_var_abs.item(), curr_idx)
            tb_logger.add_scalar("train/constraint_loss_mean", -1 * weight_mean.item(), curr_idx)
            tb_logger.add_scalar("train/constraint_loss_var", -1 * weight_var.item(), curr_idx)

            # get the constraint weight
            lambda_ = []
            xi_ = []
            for m in net.modules():
                if isinstance(m, Constraint_Lagrangian):
                    lambda_.append(m.lambda_.data.abs().mean())
                    xi_.append(m.xi_.data.abs().mean())
            lambda_ = torch.max(torch.stack(lambda_))
            xi_ = torch.max(torch.stack(xi_))
            tb_logger.add_scalar("train/constraint_lambda_", lambda_.item(), curr_idx)
            tb_logger.add_scalar("train/constraint_xi_", xi_.item(), curr_idx)

    tb_logger.add_scalar("train/train_loss_epoch", train_loss_avg / len(trainloader), epoch)
    tb_logger.add_scalar("train/train_acc_epoch", 100. * correct / total, epoch)
    wandb.log({"train/acc_epoch": 100. * correct / total}, step=epoch)
    wandb.log({"train/loss_epoch": train_loss_avg / len(trainloader)}, step=epoch)
    wandb.log({"train/norm_mean(abs)": mean.item()}, step=epoch)
    wandb.log({"train/norm_var-1(abs)": var.item()}, step=epoch)
    wandb.log({"train/weight_mean(abs)": weight_mean_abs.item()}, step=epoch)
    wandb.log({"train/weight_var-1(abs)": weight_var_abs.item()}, step=epoch)
    wandb.log({"train/constraint_loss_mean": -1 * weight_mean.item()}, step=epoch)
    wandb.log({"train/constraint_loss_var": -1 * weight_var.item()}, step=epoch)
    logger.info("epoch: {} acc: {}, loss: {}".format(
        epoch, 100. * correct / total, train_loss_avg / len(trainloader)))
    for m in net.modules():
        if isinstance(m, Constraint_Norm):
            m.reset_norm_statistics()
    return (train_loss.avg, reg_loss.avg, 100. * correct / total)
def step(self, closure: Optional[Callable] = None) -> None:
    xm.optimizer_step(self.wrapped_optimizer, barrier=True)
def train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=10,
                model_type="VS", weight_file="best_modelweights.dat", L1_loss=0,
                suppress_log=False, hyperparam_search=False, use_tpu=False,
                multigpu=False, tensorboard=True):
    if use_tpu:
        print("using TPU acceleration, model and optimizer should already be loaded onto tpu device")
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("using GPU acceleration")
    if multigpu and torch.cuda.device_count() > 1:
        print("multigpu enabled")
        model = nn.DataParallel(model)
        model = model.to(device, dtype=torch.float)
    else:
        model = model.to(device, dtype=torch.float)

    since = time.time()
    best_loss = np.Inf
    # train_losses = np.zeros(num_epochs*dataset_sizes['train'])
    # val_losses = np.zeros(num_epochs*dataset_sizes['val'])
    train_losses = np.zeros(num_epochs * len(dataloaders['train']))
    val_losses = np.zeros(num_epochs * len(dataloaders['val']))
    it_val = 0
    it_train = 0
    if tensorboard:
        writer = SummaryWriter()

    for epoch in range(num_epochs):
        if suppress_log == False:
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # initialize the predictions
            if dataloaders[phase].dataset.include_torque:
                predictions = np.empty((0, 6))
            else:
                predictions = np.empty((0, 3))

            running_loss = 0.0

            # Iterate over data.
            batch_size = 0
            it = 1
            for inputs, aug_inputs, labels in dataloaders[phase]:
                # zero the parameter gradients
                optimizer.zero_grad()
                if model_type != "S":
                    inputs = inputs.to(device, dtype=torch.float)
                if (model_type != "V") or (model_type != "V_RNN"):
                    aug_inputs = aug_inputs.to(device, dtype=torch.float)
                labels = labels.to(device, dtype=torch.float)

                # forward
                # track history if only in train
                if phase == 'train':
                    torch.set_grad_enabled(True)
                    if (model_type == "V") or (model_type == "V_RNN"):
                        outputs = model(inputs)
                    elif model_type == "VS":
                        outputs = model(inputs, aug_inputs)
                    else:
                        outputs = model(aug_inputs)
                    loss = criterion(outputs, labels)
                    if L1_loss:
                        L1 = 0
                        for param in model.parameters():
                            if param.requires_grad:
                                L1 += L1_loss * torch.sum(torch.abs(param))
                        loss = loss + L1
                    if multigpu:
                        loss.mean().backward()
                    else:
                        loss.backward()
                    if use_tpu:
                        xm.optimizer_step(optimizer, barrier=True)
                    else:
                        optimizer.step()
                else:
                    torch.set_grad_enabled(False)
                    if (model_type == "V") or (model_type == "V_RNN"):
                        outputs = model(inputs)
                    elif model_type == "VS":
                        outputs = model(inputs, aug_inputs)
                    else:
                        outputs = model(aug_inputs)
                    loss = criterion(outputs, labels)
                    predictions = np.vstack((predictions, outputs.cpu().detach().numpy()))

                # statistics
                running_loss += loss.item()  # * inputs.size(0)  # multiply by the number of elements to get back the total loss; usually the loss function outputs the mean
                batch_size += inputs.size(0)
                avg_loss = running_loss / batch_size
                if phase == 'train':
                    train_losses[it_train] = avg_loss
                    if tensorboard:
                        writer.add_scalar('Loss/train', avg_loss, it_train)
                    it_train += 1
                else:
                    val_losses[it_val] = avg_loss
                    if tensorboard:
                        writer.add_scalar('Loss/val', avg_loss, it_val)
                    it_val += 1
                if it % 100 == 0 and suppress_log == False:
                    print('average loss for batch ' + str(it) + ' : ' + str(avg_loss))
                it += 1

            # divide by the total size of our dataset to get the mean loss per instance
            epoch_loss = running_loss / dataset_sizes[phase]
            if tensorboard:
                if phase == "train":
                    writer.add_scalar('ELoss/train', epoch_loss, epoch)
                if phase == "val":
                    writer.add_scalar('ELoss/val', epoch_loss, epoch)

            if suppress_log == False:
                print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                if hyperparam_search == False:
                    print('Saving model... current loss:' + str(round(epoch_loss, 5)) +
                          ' < best loss: ' + str(round(best_loss, 5)))
                    print("Backing up the model")
                    temp_file = open(weight_file, "wb")
                    torch.save(model.state_dict(), temp_file)
                    if tensorboard:
                        fig, ax = plt.subplots(3, 1, sharex=True, figsize=(50, 10))
                        plt.ioff()
                        for f_ax in range(3):
                            ax[f_ax].plot(dataloaders[phase].dataset.label_array[:, f_ax + 1])
                            ax[f_ax].plot(predictions[:, f_ax], linewidth=1)
                        writer.add_figure('valPred/figure', fig, global_step=epoch, close=True)
                else:
                    print('current loss:' + str(round(epoch_loss, 5)) +
                          ' < best loss: ' + str(round(best_loss, 5)))
                best_loss = epoch_loss

        if suppress_log == False:
            time_elapsed = time.time() - since
            print('Epoch runtime {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    if hyperparam_search == False:
        temp_file.close()
        temp_file = open(weight_file, "rb")
        model.load_state_dict(torch.load(temp_file))

    return model, train_losses, val_losses, best_loss
def train(args, train_dataset, model, tokenizer, disable_logging=False):
    """Train the model."""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    train_sampler = get_sampler(train_dataset)
    dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataloader) * args.train_batch_size)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per TPU core = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        (args.train_batch_size * args.gradient_accumulation_steps * xm.xrt_world_size()),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    loss = None
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        # tpu-comment: Get TPU parallel loader which sends data to TPU in background.
        train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader),
                              disable=disable_logging)
        for step, batch in enumerate(epoch_iterator):

            # Save model checkpoint.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                logger.info("Saving model checkpoint to %s", output_dir)
                if xm.is_master_ordinal():
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))

                # Barrier to wait for saving checkpoint.
                xm.rendezvous("mid_training_checkpoint")
                # model.save_pretrained needs to be called by all ordinals
                model.save_pretrained(output_dir)

            model.train()
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics.
                    results = {}
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, disable_logging=disable_logging)
                    loss_scalar = loss.item()
                    logger.info(
                        "global_step: {global_step}, lr: {lr:.6f}, loss: {loss:.3f}".format(
                            global_step=global_step, lr=scheduler.get_lr()[0], loss=loss_scalar))
                    if xm.is_master_ordinal():
                        # tpu-comment: All values must be on CPU and not on the TPU device
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", loss_scalar, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if xm.is_master_ordinal():
        tb_writer.close()
    return global_step, loss.item()
def train(self, data_loader):
    losses = AverageMeter()
    self.model.train()
    print_idx = int(len(data_loader) * self.tpu_print / 100)
    if self.accumulation_steps > 1:
        self.optimizer.zero_grad()
    if self.use_tpu:
        para_loader = pl.ParallelLoader(data_loader, [self.device])
        tk0 = para_loader.per_device_loader(self.device)
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))

    for b_idx, data in enumerate(tk0):
        if self.accumulation_steps == 1 and b_idx == 0:
            self.optimizer.zero_grad()
        if self.model_fn is None:
            for key, value in data.items():
                data[key] = value.to(self.device)
            _, loss = self.model(**data)
        else:
            if self.fp16:
                with amp.autocast():
                    loss = self.model_fn(data, self.device, self.model)
            else:
                loss = self.model_fn(data, self.device, self.model)

        if not self.use_tpu:
            with torch.set_grad_enabled(True):
                if self.use_mean_loss:
                    loss = loss.mean()
                if self.fp16:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                if (b_idx + 1) % self.accumulation_steps == 0:
                    if self.fp16:
                        self.scaler.step(self.optimizer)
                    else:
                        self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    if b_idx > 0:
                        self.optimizer.zero_grad()
                    if self.fp16:
                        self.scaler.update()
        else:
            loss.backward()
            xm.optimizer_step(self.optimizer)
            if self.scheduler is not None:
                self.scheduler.step()
            if b_idx > 0:
                self.optimizer.zero_grad()

        if self.use_tpu:
            reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
            losses.update(reduced_loss.item(), data_loader.batch_size)
        else:
            losses.update(loss.item(), data_loader.batch_size)

        if not self.use_tpu:
            tk0.set_postfix(loss=losses.avg)
        else:
            if b_idx % print_idx == 0 or b_idx == len(data_loader):
                xm.master_print(
                    f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}")

    if not self.use_tpu:
        tk0.close()
    return losses.avg