def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
    eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    eval_examples = self.eval_examples if eval_examples is None else eval_examples

    # Temporarily disable metric computation, we will do it in the loop here.
    compute_metrics = self.compute_metrics
    self.compute_metrics = None
    try:
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )
    finally:
        self.compute_metrics = compute_metrics

    if self.post_process_function is not None and self.compute_metrics is not None:
        eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
        metrics = self.compute_metrics(eval_preds)
        self.log(metrics)
    else:
        metrics = {}

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
    return metrics

def _mp_fn(rank, args):
    print("rank", rank)
    device = xm.xla_device()
    # devices = (
    #     xm.get_xla_supported_devices(
    #         max_devices=args.num_cores) if args.num_cores != 0 else [])
    # with _LOAD_LOCK:
    #     _MODEL.to(device)
    xm.master_print('done loading model')

    criterion = LabelSmoothedLengthGan_CrossEntropyCriterion(args, translation_self.tgt_dict)
    params = list(filter(lambda p: p.requires_grad, _MODEL.parameters()))
    optimizer = FairseqAdam(args, params)
    lr_scheduler = InverseSquareRootSchedule(args, optimizer)

    for epoch in range(args.num_epochs):
        # train_loop_fn(args, _MODEL, criterion, optimizer, device)
        # valid_log = eval_loop_fn(args, _MODEL, criterion, device)
        para_loader = pl.ParallelLoader(valid_dataloader, [device])
        train_loop_fn(para_loader.per_device_loader(device), args, _MODEL, criterion, optimizer, device)

        para_loader = pl.ParallelLoader(valid_dataloader, [device])
        valid_log = eval_loop_fn(para_loader.per_device_loader(device), args, _MODEL, criterion, device)

        xm.master_print('Finished training epoch {}'.format(epoch))
        xm.master_print(
            "Epoch {}, loss {:.4f}, nll_loss {:.4f}, length_loss {:.4f}, dis_loss {:.4f}".format(
                epoch, valid_log["loss"], valid_log["nll_loss"],
                valid_log["length_loss"], valid_log["dis_loss"]))
        lr_scheduler.step(epoch)
        if args.checkpoint_path:
            xm.save(_MODEL.state_dict(), args.checkpoint_path)

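# How _mp_fn is launched is not part of the snippet above; a minimal launch sketch,
# assuming the standard torch_xla multiprocessing entry point and a hypothetical
# parse_args() that builds the same `args` object used above:
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    args = parse_args()  # hypothetical: however the surrounding script builds its config
    # Each TPU core runs _mp_fn(rank, args); nprocs=8 targets a full v3-8 host.
    xmp.spawn(_mp_fn, args=(args,), nprocs=8, start_method='fork')
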
def evaluate(self, data_loader, return_predictions=False):
    losses = AverageMeter()
    print_idx = int(len(data_loader) * self.tpu_print / 100)
    self.model.eval()
    final_predictions = []
    with torch.no_grad():
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(self.device)
            if self.fp16:
                with amp.autocast():
                    batch_preds, loss = self.model(**data)
            else:
                batch_preds, loss = self.model(**data)
            if return_predictions:
                final_predictions.append(batch_preds)
            if self.use_tpu:
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                if self.use_mean_loss:
                    loss = loss.mean()
                losses.update(loss.item(), data_loader.batch_size)
            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader):
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
    if not self.use_tpu:
        tk0.close()
    return losses.avg, final_predictions

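# The reduce_fn handed to xm.mesh_reduce above is defined elsewhere; a plausible
# minimal version (an assumption, not the original helper) simply averages the
# per-core loss tensors that mesh_reduce collects:
def reduce_fn(vals):
    # vals is a list with one loss value per TPU core; averaging keeps it a tensor,
    # so the caller's .item() still works.
    return sum(vals) / len(vals)
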
def __init__(self, device, config, steps):
    self.config = config
    self.epoch = 0
    self.steps = steps
    self.base_dir = './checkpoints'
    self.log_path = f'{self.base_dir}/log.txt'
    self.best_summary_loss = 10**5

    # get pretrained models
    self.model = Customized_ENSModel(global_config.EfficientNet_Level)
    xm.master_print(">>> Model loaded!")
    self.device = device
    # self.model = self.model.to(device)

    if global_config.LOSS_FN_LabelSmoothing:
        self.criterion = LabelSmoothing()
    else:
        class_weights = torch.FloatTensor(global_config.CLASS_WEIGHTS)
        self.criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        print(f">>> Class Weights: {global_config.CLASS_WEIGHTS}")

    self.log(f'>>> Model is loaded. Main Device is {self.device}')

def sync_bn1d_no_channel_test(index):
    torch.manual_seed(1)
    bsz = 32
    length = 64
    t_global = torch.rand((xm.xrt_world_size() * bsz, length))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(length).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm1d(length)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn1d_no_channel_test')
    xm.master_print('sync_bn1d_no_channel_test ok')

def train_loop_fn(train_loader, args, model, criterion, optimizer, device, scheduler=None):
    model.train()
    criterion.train()
    for i, sample in enumerate(train_loader):
        sample = _prepare_sample(sample, device)
        print(sample["target"].shape, sample["target"].device)
        optimizer.zero_grad()
        _, _, logging_output = criterion(model, sample)
        logging = criterion.aggregate_logging_outputs([logging_output])
        loss = logging["loss"]
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        if i % args.log_steps == 0:
            xm.master_print('bi={}, loss={:.4f}'.format(i, loss.item()))
            xm.master_print('MEM: {}'.format(psutil.virtual_memory()))
    print('End training: {}'.format(device))

def train_model_xla(net, batch_size, lr, num_epochs, log_steps=20, metrics_debug=False):
    torch.manual_seed(1)
    train_loader, test_loader = load_cifar_10_xla(batch_size)

    # Scale learning rate to num cores
    lr = lr * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), net, optimizer, loss_fn, batch_size, log_steps)
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device), net)
        if metrics_debug:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target

def train_one_epoch(self, train_loader):
    self.model.train()
    summary_loss = AverageMeter()
    final_scores = RocAucMeter()
    t = time.time()
    for step, (images, targets) in enumerate(train_loader):
        t0 = time.time()
        batch_size = images.shape[0]

        outputs = self.model(images)
        self.optimizer.zero_grad()
        loss = self.criterion(outputs, targets)
        loss.backward()  # compute and sum gradients on params
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=global_config.CLIP_GRAD_NORM)
        xm.optimizer_step(self.optimizer)
        if self.config.step_scheduler:
            self.scheduler.step()

        try:
            final_scores.update(targets, outputs)
        except:
            # xm.master_print("outputs: ", list(outputs.data.cpu().numpy())[:10])
            pass
        summary_loss.update(loss.detach().item(), batch_size)

        if self.config.verbose:
            if step % self.config.verbose_step == 0:
                t1 = time.time()
                effNet_lr = np.format_float_scientific(self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
                head_lr = np.format_float_scientific(self.optimizer.param_groups[1]['lr'], unique=False, precision=1)
                xm.master_print(
                    f":::({str(step).rjust(4, ' ')}/{len(train_loader)}) | "
                    f"Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.5f} | "
                    f"LR: {effNet_lr}/{head_lr} | BTime: {t1-t0 :.2f}s | "
                    f"ETime: {int((t1-t0)*(len(train_loader)-step)//60)}m")
    return summary_loss, final_scores

def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics, as they are
    task-dependent (pass it to the init :obj:`compute_metrics` argument).

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Pass a dataset if you wish to override :obj:`self.eval_dataset`.

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
    """
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    output = self._prediction_loop(eval_dataloader, description="Evaluation")
    self._log(output.metrics)

    logger.info("Evaluate results")
    logger.info(output.metrics)

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    return output.metrics

def __call__(self, epoch_score, model, model_path):
    if self.mode == "min":
        score = -1.0 * epoch_score
    else:
        score = np.copy(epoch_score)

    if self.best_score is None:
        self.best_score = score
        self.save_checkpoint(epoch_score, model, model_path)
    elif score < self.best_score + self.delta:
        self.counter += 1
        if self.tpu:
            xm.master_print("EarlyStopping counter: {} out of {}".format(self.counter, self.patience))
        else:
            print("EarlyStopping counter: {} out of {}".format(self.counter, self.patience))
        if self.counter >= self.patience:
            self.early_stop = True
    else:
        self.best_score = score
        self.save_checkpoint(epoch_score, model, model_path)
        self.counter = 0

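# save_checkpoint is referenced above but not shown; a minimal sketch that mirrors
# the TPU/CPU split used in __call__ (an assumption, not the original method):
def save_checkpoint(self, epoch_score, model, model_path):
    if self.tpu:
        xm.master_print(f"Validation score improved to {epoch_score}. Saving model!")
        xm.save(model.state_dict(), model_path)  # xm.save writes from the master ordinal only
    else:
        print(f"Validation score improved to {epoch_score}. Saving model!")
        torch.save(model.state_dict(), model_path)
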
def tpu_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()
    cnt = 0
    total = len(data_loader)
    for samples, targets in data_loader:
        print('test')
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        loss = loss_dict['loss_giou']
        xm.master_print('Number: {}/{}, Loss:{}'.format(cnt + 1, total, loss.item()))
        cnt += 1

    xm.master_print(met.metrics_report())
    exit()

def __init__(self, model, device, config):
    if not os.path.exists('node_submissions'):
        os.makedirs('node_submissions')
    self.config = config
    self.epoch = 0
    self.log_path = 'log.txt'

    self.model = model
    self.device = device

    param_optimizer = list(self.model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.0008},  # default 0.001
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr * xm.xrt_world_size())
    self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params)
    self.criterion = config.criterion
    xm.master_print(f'Fitter prepared. Device is {self.device}')

def train_fn(epoch, train_dataloader, optimizer, criterion, scheduler, device):
    model.train()
    for batch_idx, batch_data in enumerate(train_dataloader):
        optimizer.zero_grad()
        batch_data = any2device(batch_data, device)
        outputs = model(**batch_data)

        y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
        y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]
        loss = criterion(y_pred, y_true)

        # Log every 100th batch (the original `if batch_idx % 100:` logged on all but those batches).
        if batch_idx % 100 == 0:
            xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()

def sync_bn3d_test(index):
    torch.manual_seed(1)
    bsz = 16
    features = 32
    d, h, w = 16, 32, 32
    t_global = torch.rand((xm.xrt_world_size() * bsz, features, d, h, w))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(features).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm3d(features)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn3d_test')
    xm.master_print('sync_bn3d_test ok')

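# Both SyncBatchNorm tests rely on run_step, assert_stats, RTOL and ATOL defined
# elsewhere in the test file; a rough sketch of what they might look like
# (assumptions, not the actual test helpers):
RTOL, ATOL = 1e-4, 1e-4  # assumed tolerances

def run_step(bn_module, inp):
    # One forward pass in training mode so the running statistics get updated.
    bn_module.train()
    out = bn_module(inp)
    if out.device.type == 'xla':
        xm.mark_step()  # materialize the XLA graph before comparing results
    return out

def assert_stats(sbn, bn):
    # After seeing its shard, the synced XLA module should hold the same running
    # statistics as a CPU BatchNorm that saw the full global batch.
    assert sbn.running_mean.allclose(bn.running_mean, rtol=RTOL, atol=ATOL)
    assert sbn.running_var.allclose(bn.running_var, rtol=RTOL, atol=ATOL)
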
def valid_fn(epoch, valid_dataloader, criterion, device):
    model.eval()
    pred_scores = []
    true_scores = []
    for batch_idx, batch_data in enumerate(valid_dataloader):
        batch_data = any2device(batch_data, device)
        outputs = model(**batch_data)

        y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
        y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]
        loss = criterion(y_pred, y_true)

        pred_scores.extend(to_numpy(parse_classifier_probas(y_pred)))
        true_scores.extend(to_numpy(y_true))
        xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

    val_wauc = alaska_weighted_auc(xla_all_gather(true_scores, device), xla_all_gather(pred_scores, device))
    xm.master_print(f"Valid epoch: {epoch}, wAUC: {val_wauc}")
    return val_wauc

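# xla_all_gather is another helper that is not shown; one plausible implementation
# (an assumption) piggybacks on xm.mesh_reduce, which gathers a Python object from
# every core and applies a reduce function to the resulting list:
def xla_all_gather(values, device, tag="all_gather"):
    # `values` is this core's list of scores; np.concatenate stitches the per-core
    # arrays into one global array. `device` is kept only for signature compatibility.
    return xm.mesh_reduce(tag, np.asarray(values), np.concatenate)
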
def prepare_task(args, xla_device):
    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build models and criteria to print some metadata
    torch.manual_seed(args.seed)
    model, criterion = task.build_model(args), task.build_criterion(args)
    xm.master_print(model)
    xm.master_print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    xm.master_print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    model = model.to(xla_device)
    trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
    lr = trainer.get_lr()

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    # we overwrite distributed args here to shard data using torch_xla's
    # distributed training.
    trainer.args.distributed_rank = xm.get_ordinal()
    trainer.args.distributed_world_size = xm.xrt_world_size()
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    trainer.args.distributed_rank = 0
    trainer.args.distributed_world_size = 1
    trainer.meters_to_device(xla_device)
    valid_subsets = args.valid_subset.split(',')
    ordinal = xm.get_ordinal(defval=-1)
    device_str = (
        str(xla_device) if ordinal < 0 else '{}/{}'.format(xla_device, ordinal)
    )
    return task, trainer, model, epoch_itr, lr, valid_subsets, device_str

def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path: (Optional) Local path to model if model to train has been instantiated from a local path
            If present, we will try reloading the optimizer/scheduler states from there.
    """
    train_dataloader = self.get_train_dataloader()
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device))
        scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

    model = self.model
    if self.args.fp16:
        if not is_apex_available():
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if self.args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())
        self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

    # Train!
    if is_torch_tpu_available():
        total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
    else:
        total_train_batch_size = (self.args.train_batch_size
                                  * self.args.gradient_accumulation_steps
                                  * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1))
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", self.num_examples(train_dataloader))
    logger.info(" Num Epochs = %d", num_train_epochs)
    logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = self.global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", self.global_step)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            self.global_step = 0
            logger.info(" Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch",
                            disable=not self.is_local_master())
    for epoch in train_iterator:
        if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(
                train_dataloader, [self.args.device]).per_device_loader(self.args.device)
            epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())

        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            tr_loss += self._training_step(model, inputs, optimizer)

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                if self.args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                if is_torch_tpu_available():
                    xm.optimizer_step(optimizer)
                else:
                    optimizer.step()

                scheduler.step()
                model.zero_grad()
                self.global_step += 1
                self.epoch = epoch + (step + 1) / len(epoch_iterator)

                if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step):
                    logs: Dict[str, float] = {}
                    logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (
                        scheduler.get_last_lr()[0]
                        if version.parse(torch.__version__) >= version.parse("1.4")
                        else scheduler.get_lr()[0])
                    logging_loss = tr_loss

                    self._log(logs)

                    if self.args.evaluate_during_training:
                        self.evaluate()

                if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                    # In all cases (even distributed/parallel), self.model is always a reference
                    # to the model we want to save.
                    if hasattr(model, "module"):
                        assert model.module is self.model
                    else:
                        assert model is self.model
                    # Save model checkpoint
                    output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                    self.save_model(output_dir)

                    if self.is_world_master():
                        self._rotate_checkpoints()

                    if is_torch_tpu_available():
                        xm.rendezvous("saving_optimizer_states")
                        xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    elif self.is_world_master():
                        torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                epoch_iterator.close()
                break

        if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
            train_iterator.close()
            break

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

    if self.tb_writer:
        self.tb_writer.close()

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    return TrainOutput(self.global_step, tr_loss / self.global_step)

def train_mnist(flags, state_dict):
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, writer),
                                    run_async=FLAGS.async_closures)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer, epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy

def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr, momentum=FLAGS.momentum, weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update, args=(device, step, loss, tracker, epoch, writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer, epoch,
                                    dict_to_write={'Accuracy/test': accuracy},
                                    write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy

def train(args, train_dataset, model, tokenizer, disable_logging=False):
    """Train the model"""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    train_sampler = get_sampler(train_dataset)
    dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(dataloader) * args.train_batch_size)
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per TPU core = %d", args.train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        (args.train_batch_size * args.gradient_accumulation_steps * xm.xrt_world_size()),
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    loss = None
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        # tpu-comment: Get TPU parallel loader which sends data to TPU in background.
        train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
        for step, batch in enumerate(epoch_iterator):
            # Save model checkpoint.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                logger.info("Saving model checkpoint to %s", output_dir)
                if xm.is_master_ordinal():
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                # Barrier to wait for saving checkpoint.
                xm.rendezvous("mid_training_checkpoint")
                # model.save_pretrained needs to be called by all ordinals
                model.save_pretrained(output_dir)

            model.train()
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics.
                    results = {}
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, disable_logging=disable_logging)
                    loss_scalar = loss.item()
                    logger.info(
                        "global_step: {global_step}, lr: {lr:.6f}, loss: {loss:.3f}".format(
                            global_step=global_step, lr=scheduler.get_lr()[0], loss=loss_scalar))
                    if xm.is_master_ordinal():
                        # tpu-comment: All values must be in CPU and not on TPU device
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", loss_scalar, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if xm.is_master_ordinal():
        tb_writer.close()
    return global_step, loss.item()

def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
    """Evaluate the model"""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
        eval_sampler = get_sampler(eval_dataset)

        if not os.path.exists(eval_output_dir):
            os.makedirs(eval_output_dir)

        dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
        eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info(" Num examples = %d", len(dataloader) * args.eval_batch_size)
        logger.info(" Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating", disable=disable_logging):
            model.eval()
            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT and RoBERTa don't use segment_ids
                    inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
                outputs = model(**inputs)
                batch_eval_loss, logits = outputs[:2]
                eval_loss += batch_eval_loss
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
        preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
        out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)
        results["eval_loss"] = eval_loss.item()

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        if xm.is_master_ordinal():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results {} *****".format(prefix))
                for key in sorted(results.keys()):
                    logger.info(" %s = %s", key, str(results[key]))
                    writer.write("%s = %s\n" % (key, str(results[key])))
                    tb_writer.add_scalar(f"{eval_task}/{key}", results[key])

    if args.metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    if xm.is_master_ordinal():
        tb_writer.close()

    return results

def train(self, data_loader):
    losses = AverageMeter()
    self.model.train()
    print_idx = int(len(data_loader) * self.tpu_print / 100)
    if self.accumulation_steps > 1:
        self.optimizer.zero_grad()
    if self.use_tpu:
        para_loader = pl.ParallelLoader(data_loader, [self.device])
        tk0 = para_loader.per_device_loader(self.device)
    else:
        tk0 = tqdm(data_loader, total=len(data_loader))

    for b_idx, data in enumerate(tk0):
        if self.accumulation_steps == 1 and b_idx == 0:
            self.optimizer.zero_grad()
        if self.model_fn is None:
            for key, value in data.items():
                data[key] = value.to(self.device)
            _, loss = self.model(**data)
        else:
            if self.fp16:
                with amp.autocast():
                    loss = self.model_fn(data, self.device, self.model)
            else:
                loss = self.model_fn(data, self.device, self.model)

        if not self.use_tpu:
            with torch.set_grad_enabled(True):
                if self.use_mean_loss:
                    loss = loss.mean()
                if self.fp16:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                if (b_idx + 1) % self.accumulation_steps == 0:
                    if self.fp16:
                        self.scaler.step(self.optimizer)
                    else:
                        self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    if b_idx > 0:
                        self.optimizer.zero_grad()
                if self.fp16:
                    self.scaler.update()
        else:
            loss.backward()
            xm.optimizer_step(self.optimizer)
            if self.scheduler is not None:
                self.scheduler.step()
            if b_idx > 0:
                self.optimizer.zero_grad()

        if self.use_tpu:
            reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
            losses.update(reduced_loss.item(), data_loader.batch_size)
        else:
            losses.update(loss.item(), data_loader.batch_size)

        if not self.use_tpu:
            tk0.set_postfix(loss=losses.avg)
        else:
            if b_idx % print_idx == 0 or b_idx == len(data_loader):
                xm.master_print(
                    f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                )

    if not self.use_tpu:
        tk0.close()
    return losses.avg

def train(x, y):
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            z_, y_ = sample()
            D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                x[counter], y[counter], train_G=False,
                                split_D=config['split_D'])

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            xm.master_print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        xm.optimizer_step(D.optim)

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        z_, y_ = sample()
        D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
        G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
        G_loss.backward()

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G
        print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'], blacklist=[param for param in G.shared.parameters()])

    xm.optimizer_step(G.optim)

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {'G_loss': G_loss, 'D_loss_real': D_loss_real, 'D_loss_fake': D_loss_fake}
    # Return G's loss and the components of D's loss.
    return out

def run(index):
    MAX_LEN = 512
    TRAIN_BATCH_SIZE = 16
    EPOCHS = 50

    dfx = pd.read_csv("/home/nizamphoenix/dataset/train.csv").fillna("none")
    df_train, df_valid = model_selection.train_test_split(dfx, random_state=42, test_size=0.3)
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    sample = pd.read_csv("/home/nizamphoenix/dataset/sample_submission.csv")
    target_cols = list(sample.drop("qa_id", axis=1).columns)
    train_targets = df_train[target_cols].values
    valid_targets = df_valid[target_cols].values

    tokenizer = transformers.BertTokenizer.from_pretrained("/home/nizamphoenix/bert-base-uncased/")

    train_dataset = BERTDatasetTraining(qtitle=df_train.question_title.values,
                                        qbody=df_train.question_body.values,
                                        answer=df_train.answer.values,
                                        targets=train_targets,
                                        tokenizer=tokenizer,
                                        max_len=MAX_LEN)
    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=TRAIN_BATCH_SIZE, sampler=train_sampler)

    valid_dataset = BERTDatasetTraining(qtitle=df_valid.question_title.values,
                                        qbody=df_valid.question_body.values,
                                        answer=df_valid.answer.values,
                                        targets=valid_targets,
                                        tokenizer=tokenizer,
                                        max_len=MAX_LEN)
    valid_sampler = torch.utils.data.DistributedSampler(
        valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=8,  # can make changes here
        sampler=valid_sampler)

    device = xm.xla_device()
    lr = 2e-5 * xm.xrt_world_size()  # can make changes here
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)

    model = BERTBaseUncased("/home/nizamphoenix/bert-base-uncased/").to(device)
    optimizer = AdamW(model.parameters(), lr=lr, eps=1e-8)  # eps=1e-8: to prevent any division by zero
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    for epoch in range(EPOCHS):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)

        spear = []
        for jj in range(t.shape[1]):
            p1 = list(t[:, jj])
            p2 = list(o[:, jj])
            coef, _ = np.nan_to_num(stats.spearmanr(p1, p2))
            spear.append(coef)
        spear = np.mean(spear)
        xm.master_print(f"epoch={epoch},spearman={spear}")
        xm.save(model.state_dict(), "model3.bin")  # change every time

def train_mnist():
    torch.manual_seed(1)
    """
    When using a TPU, the dataset setup looks like:
        train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    """
    def get_dataset():
        norm = transforms.Normalize((0.1307,), (0.3081,))
        train_dataset = datasets.MNIST(
            FLAGS['datadir'], train=True, download=True,
            transform=transforms.Compose([transforms.ToTensor(), norm]))
        test_dataset = datasets.MNIST(
            FLAGS['datadir'], train=False, download=True,
            transform=transforms.Compose([transforms.ToTensor(), norm]))
        return train_dataset, test_dataset

    # Using the serial executor avoids multiple processes to
    # download the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # Scale learning rate to world size
    lr = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    """
    When using a TPU, the device and model setup looks like:
        device = xm.xla_device()
        model = xmp.MpModelWrapper(MNIST()).to(device)
    """
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            # When using a TPU, step the optimizer with xm.optimizer_step(optimizer)
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if x % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target

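# train_mnist above assumes the module-level globals SERIAL_EXEC and WRAPPED_MODEL
# plus a spawn entry point; a minimal sketch of those pieces (hypothetical, following
# the usual torch_xla pattern):
import torch_xla.distributed.xla_multiprocessing as xmp

SERIAL_EXEC = xmp.MpSerialExecutor()          # runs get_dataset once per host instead of once per core
WRAPPED_MODEL = xmp.MpModelWrapper(MNIST())   # shares one copy of the initial weights across processes

def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    train_mnist()

# 'num_cores' is a hypothetical key; adjust to however FLAGS is actually populated.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')
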
def train_loop(folds, fold):
    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds, transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, transform=get_transforms(data='valid'))

    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset,
                                  batch_size=CFG.batch_size,
                                  shuffle=True,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=CFG.batch_size * 2,
                                  shuffle=False,
                                  num_workers=CFG.num_workers,
                                  pin_memory=True,
                                  drop_last=False)
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size * 2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler == 'ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor,
                                          patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler == 'CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1,
                                                    eta_min=CFG.min_lr, last_epoch=-1)
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomResNet200D_WLF(CFG.model_name, pretrained=False)
    model.load_state_dict(torch.load(CFG.student, map_location=torch.device('cpu'))['model'])
    model.to(device)

    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf

    for epoch in range(CFG.epochs):
        start_time = time.time()

        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(para_train_loader.per_device_loader(device),
                                    model, criterion, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(para_valid_loader.per_device_loader(device),
                                                             model, criterion, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device)

        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()

        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        if CFG.device == 'GPU':
            LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
            LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}')
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
                LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}')
            elif CFG.nprocs == 8:
                xm.master_print(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s')
                xm.master_print(f'Epoch {epoch+1} - Score: {score:.4f} Scores: {np.round(scores, decimals=4)}')

        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                torch.save({'model': model.state_dict(), 'preds': preds},
                           OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                xm.save({'model': model.state_dict(), 'preds': preds},
                        OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({'model': model.state_dict(), 'preds': preds},
                           OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                xm.save({'model': model.state_dict(), 'preds': preds},
                        OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_loss.pth')

    if CFG.nprocs != 8:
        check_point = torch.load(OUTPUT_DIR + f'{CFG.model_name}_fold{fold}_best_score.pth')
        for c in [f'pred_{c}' for c in CFG.target_cols]:
            valid_folds[c] = np.nan
        valid_folds[[f'pred_{c}' for c in CFG.target_cols]] = check_point['preds']

    return valid_folds

def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    trues = []
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            _, _, y_preds = model(images)
            loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        trues.append(labels.to('cpu').numpy())
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
                print('EVAL: [{0}/{1}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '.format(
                          step, len(valid_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          remain=timeSince(start, float(step + 1) / len(valid_loader)),
                      ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1):
                xm.master_print('EVAL: [{0}/{1}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '.format(
                                    step, len(valid_loader),
                                    batch_time=batch_time,
                                    data_time=data_time,
                                    loss=losses,
                                    remain=timeSince(start, float(step + 1) / len(valid_loader)),
                                ))
    trues = np.concatenate(trues)
    predictions = np.concatenate(preds)
    return losses.avg, predictions, trues

def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                _, _, y_preds = model(images)
                loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            scaler.scale(loss).backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1
        elif CFG.device == 'TPU':
            _, _, y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f} '
                      # 'LR: {lr:.6f} '
                      .format(
                          epoch + 1, step, len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          remain=timeSince(start, float(step + 1) / len(train_loader)),
                          grad_norm=grad_norm,
                          # lr=scheduler.get_lr()[0],
                      ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f} '
                                # 'LR: {lr:.6f} '
                                .format(
                                    epoch + 1, step, len(train_loader),
                                    batch_time=batch_time,
                                    data_time=data_time,
                                    loss=losses,
                                    remain=timeSince(start, float(step + 1) / len(train_loader)),
                                    grad_norm=grad_norm,
                                    # lr=scheduler.get_lr()[0],
                                ))
    return losses.avg

def train_tpu():
    torch.manual_seed(1)

    def get_dataset():
        fold_number = 0
        train_ = pd.read_csv(args.train_fold)
        train = ShopeeDataset(
            train_[train_['fold'] != fold_number].reset_index(drop=True))
        # hold out the validation fold (the original used != here as well,
        # which made the train and test splits identical)
        test = ShopeeDataset(
            train_[train_['fold'] == fold_number].reset_index(drop=True),
            transform=args.test_args)
        return train, test

    # Using the serial executor avoids multiple processes
    # downloading the same data.
    train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.num_workers,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              drop_last=True)

    # Scale learning rate to the number of cores
    learning_rate = 1e-5 * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, label) in enumerate(loader):
            optimizer.zero_grad()
            output = model(image=data, label=label,
                           get_embedding=args.get_embeddings)
            loss = loss_fn(output, label)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(args.batch_size)
            if x % 20 == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                      .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                              tracker.global_rate(), time.asctime()),
                      flush=True)

    def test_loop_fn(loader):
        model.eval()
        for x, (data, label) in enumerate(loader):
            output = model(image=data, label=label,
                           get_embedding=args.get_embeddings)
            loss = loss_fn(output, label)
            if x % 20 == 0:
                print('[xla:{}]({}) Loss={:.5f}'.format(
                    xm.get_ordinal(), x, loss.item()),
                    flush=True)

    for epoch in range(1, args.n_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        test_loop_fn(para_loader.per_device_loader(device))
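train_tpu relies on the module-level SERIAL_EXEC and WRAPPED_MODEL globals and is meant to run once per TPU core. A minimal launch sketch, assuming eight cores and a hypothetical ShopeeNet model class (neither is defined in the snippet above); torch and args are assumed to come from the surrounding script:

import torch_xla.distributed.xla_multiprocessing as xmp

# Serialize dataset setup and share one host copy of the weights across processes.
SERIAL_EXEC = xmp.MpSerialExecutor()
WRAPPED_MODEL = xmp.MpModelWrapper(ShopeeNet(args))  # ShopeeNet is hypothetical

def _mp_fn(rank, flags):
    # Each spawned process runs the full training loop on its own core.
    torch.set_default_tensor_type('torch.FloatTensor')
    train_tpu()

xmp.spawn(_mp_fn, args=({},), nprocs=8, start_method='fork')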
def train_fn(df):
    size = 1
    torch.manual_seed(42)
    df = shuffle(df)
    split = np.int32(SPLIT * len(df))
    val_df, train_df = df[split:], df[:split]

    val_df = val_df.reset_index(drop=True)
    val_set = QuoraDataset(val_df, tokenizer)
    val_sampler = DistributedSampler(val_set, num_replicas=8,
                                     rank=xm.get_ordinal(), shuffle=True)

    train_df = train_df.reset_index(drop=True)
    train_set = QuoraDataset(train_df, tokenizer)
    train_sampler = DistributedSampler(train_set, num_replicas=8,
                                       rank=xm.get_ordinal(), shuffle=True)

    val_loader = DataLoader(val_set, VAL_BATCH_SIZE, sampler=val_sampler,
                            num_workers=0, drop_last=True)
    train_loader = DataLoader(train_set, BATCH_SIZE, sampler=train_sampler,
                              num_workers=0, drop_last=True)

    device = xm.xla_device()
    network = Roberta().to(device)
    optimizer = Adam([{'params': network.roberta.parameters(), 'lr': LR[0] * size},
                      {'params': network.dense_output.parameters(), 'lr': LR[1] * size}])

    val_losses, val_f1s = [], []
    train_losses, train_f1s = [], []

    start = time.time()
    xm.master_print("STARTING TRAINING ...\n")
    for epoch in range(EPOCHS):
        batch = 1
        network.train()
        fonts = (fg(48), attr('reset'))
        xm.master_print(("EPOCH %s" + str(epoch + 1) + "%s") % fonts)

        val_parallel = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
        train_parallel = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

        for train_batch in train_parallel:
            train_targ, train_in, train_att = train_batch
            network = network.to(device)
            train_in = train_in.to(device)
            train_att = train_att.to(device)
            train_targ = train_targ.to(device)

            train_preds = network.forward(train_in, train_att)
            train_loss = bce(train_preds, train_targ) / len(train_preds)
            train_f1 = f1_score(train_preds, train_targ.squeeze(dim=1))

            optimizer.zero_grad()
            train_loss.backward()
            xm.optimizer_step(optimizer)

            end = time.time()
            batch = batch + 1
            is_print = batch % 10 == 1
            f1 = np.round(train_f1.item(), 3)
            if is_print:
                print_metric(f1, batch, None, start, end, metric="F1", typ="Train")

        val_loss, val_f1, val_points = 0, 0, 0
        network.eval()
        with torch.no_grad():
            for val_batch in val_parallel:
                targ, val_in, val_att = val_batch
                targ = targ.to(device)
                val_in = val_in.to(device)
                val_att = val_att.to(device)
                network = network.to(device)

                pred = network.forward(val_in, val_att)
                val_points += len(targ)
                val_loss += bce(pred, targ).item()
                val_f1 += f1_score(pred, targ.squeeze(dim=1)).item() * len(pred)

        end = time.time()
        val_f1 /= val_points
        val_loss /= val_points
        f1 = xm.mesh_reduce('f1', val_f1, lambda x: sum(x) / len(x))
        loss = xm.mesh_reduce('loss', val_loss, lambda x: sum(x) / len(x))
        print_metric(np.round(f1, 3), None, epoch, start, end, metric="F1", typ="Val")
        xm.master_print("")

        val_f1s.append(f1)
        train_f1s.append(train_f1.item())
        val_losses.append(loss)
        train_losses.append(train_loss.item())

    xm.master_print("ENDING TRAINING ...")
    xm.save(network.state_dict(), MODEL_SAVE_PATH)
    del network
    gc.collect()

    metric_lists = [val_losses, train_losses, val_f1s, train_f1s]
    metric_names = ['val_loss_', 'train_loss_', 'val_f1_', 'train_f1_']
    for i, metric_list in enumerate(metric_lists):
        for j, metric_value in enumerate(metric_list):
            torch.save(metric_value, metric_names[i] + str(j) + '.pt')