def write_to_summary(summary_writer,
                     global_step,
                     dict_to_write=None,
                     write_xla_metrics=False):
    """Writes scalars to a SummaryWriter. Optionally writes XLA perf metrics.

    Args:
      summary_writer (Tensorboard SummaryWriter): The SummaryWriter to write to.
        If None, no summary files will be written.
      global_step (int): The global step value for these data points.
      dict_to_write (dict, optional): Dict where key is the scalar name and value
        is the scalar value to be written to Tensorboard.
      write_xla_metrics (bool, optional): If true, this method will retrieve XLA
        performance metrics, parse them, and write them as scalars to Tensorboard.
    """
    if summary_writer is None:
        return
    # Avoid a mutable default argument; fall back to an empty dict here instead.
    dict_to_write = dict_to_write or {}
    for k, v in dict_to_write.items():
        summary_writer.add_scalar(k, v, global_step)
    if write_xla_metrics:
        metrics = mcu.parse_metrics_report(met.metrics_report())
        aten_ops_sum = 0
        for metric_name, metric_value in metrics.items():
            if metric_name.startswith('aten::'):
                aten_ops_sum += metric_value
            summary_writer.add_scalar(metric_name, metric_value, global_step)
        summary_writer.add_scalar('aten_ops_sum', aten_ops_sum, global_step)
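A minimal usage sketch for write_to_summary above, assuming torch.utils.tensorboard is available; the log directory, step count, and scalar name are illustrative only:

# Hypothetical usage sketch for write_to_summary; the writer path and the
# scalar values are illustrative, not from the original code.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/example')
for step in range(3):
    write_to_summary(
        writer,
        global_step=step,
        dict_to_write={'loss/train': 1.0 / (step + 1)},
        write_xla_metrics=False)  # True would also log parsed XLA metrics
writer.close()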
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    prediction_loss_only: Optional[bool] = None,
) -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics,
    as they are task-dependent.

    Args:
        eval_dataset: (Optional) Pass a dataset if you wish to override the one on the instance.

    Returns:
        A dict containing:
            - the eval loss
            - the potential metrics computed from the predictions
    """
    eval_dataloader = self.get_eval_dataloader(eval_dataset)

    output = self._prediction_loop(eval_dataloader, description="Evaluation")

    self._log(output.metrics)

    if self.args.tpu_metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    perplexity = math.exp(output.metrics["eval_loss"])
    with open('real_time_output_perplexity_1.txt', 'a') as fw:
        fw.write("perplexity:{}\n".format(perplexity))

    return output.metrics
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    positions = torch.arange(SEQUENCE_LENGTH).long().view(
        1, SEQUENCE_LENGTH).to(device)
    causal_mask = torch.triu(
        torch.ones(
            SEQUENCE_LENGTH,
            SEQUENCE_LENGTH,
            dtype=torch.uint8,
            device=device),
        diagonal=1).unsqueeze(0)
    model.train()
    for iteration, batch in enumerate(loader):
        input = batch[:, :-1].long()
        target = batch[:, 1:].long()
        loss = model(input, positions, target, batch_mask=causal_mask)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(BATCH_SIZE)
        if iteration % LOG_STEPS == 0:
            print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                device, iteration, loss.item() / math.log(2), tracker.rate()))
        if iteration % METRICS_STEP == 0:
            xm.master_print(met.metrics_report())
def evaluate(self,
             eval_datasets: Optional[Dict[str, Dataset]] = None,
             metric_key_prefix: str = "eval") -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics, as they
    are task-dependent (pass it to the init :obj:`compute_metrics` argument).

    You can also subclass and override this method to inject custom behavior.

    Args:
        eval_datasets (:obj:`Dict[str, Dataset]`, `optional`):
            Pass datasets if you wish to override :obj:`self.eval_dataset`. If an entry is a
            :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are
            automatically removed. It must implement the :obj:`__len__` method.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            An optional prefix to be used as the metrics key prefix. For example the metric
            "bleu" will be named "eval_bleu" if the prefix is "eval" (default).

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from
        the predictions. The dictionary also contains the epoch number which comes from the
        training state.
    """
    results = {}
    if eval_datasets is None:
        eval_datasets = self.eval_dataset
    for eval_task, eval_dataset in eval_datasets.items():
        self.compute_metrics = self.multi_task_compute_metrics[eval_task]
        model_config = self.model.config
        use_task_specific_params(self.model, eval_task)
        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            metric_key_prefix=metric_key_prefix,
        )
        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        tasks_metric = {eval_task + "_" + k: v for k, v in output.metrics.items()}
        for key in sorted(tasks_metric.keys()):
            logger.info(f"  {key} = {tasks_metric[key]}")
        results.update(tasks_metric)
        reset_config(self.model, model_config)

    # Computes the average metrics across all the tasks without their corresponding losses.
    metrics = [results[key] for key in results.keys() if "loss" not in key]
    results[metric_key_prefix + '_average_metrics'] = np.mean(metrics)

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, results)
    return results
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    prediction_loss_only: Optional[bool] = None,
) -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics,
    as they are task-dependent.

    Args:
        eval_dataset: (Optional) Pass a dataset if you wish to override the one on the instance.

    Returns:
        A dict containing:
            - the eval loss
            - the potential metrics computed from the predictions
    """
    eval_dataloader = self.get_eval_dataloader(eval_dataset)

    output = self._prediction_loop(eval_dataloader, description="Evaluation")

    if self.args.store_best_model:
        self.store_best_model(output)

    self._log(output.metrics)

    if self.args.tpu_metrics_debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    return output.metrics
def xla_metrics_report():
    try:
        import torch_xla.debug.metrics as met
        print(met.metrics_report())
    except ImportError:
        return
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
    eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    eval_examples = self.eval_examples if eval_examples is None else eval_examples

    # Temporarily disable metric computation, we will do it in the loop here.
    compute_metrics = self.compute_metrics
    self.compute_metrics = None
    try:
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )
    finally:
        self.compute_metrics = compute_metrics

    if self.post_process_function is not None and self.compute_metrics is not None:
        eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
        metrics = self.compute_metrics(eval_preds)
        self.log(metrics)
    else:
        metrics = {}

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
    return metrics
def run_benchmark(args, pos_args):
    devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
    shape = [int(x) for x in args.shape.split(',')]
    send_list = []
    for i in range(0, len(devices)):
        mb = []
        for j in range(0, args.prefetch):
            mb.append(torch.randn(*shape))
        send_list.append(mb)

    def threadfn(i):
        device = devices[i]
        xdevices = [device] * len(send_list[i])
        for n in range(0, args.test_count):
            with xu.TimedScope(msg='Send[{}][{}]: '.format(i, n), printfn=print):
                _ = torch_xla._XLAC._xla_tensors_from_aten(send_list[i], xdevices)

    threads = []
    for i in range(0, len(devices)):
        t = threading.Thread(target=threadfn, args=(i,))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(met.metrics_report())
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics, as they
    are task-dependent (pass it to the init :obj:`compute_metrics` argument).

    You can also subclass and override this method to inject custom behavior.

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an
            :obj:`datasets.Dataset`, columns not accepted by the ``model.forward()`` method are
            automatically removed. It must implement the :obj:`__len__` method.

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from
        the predictions.
    """
    if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
        raise ValueError("eval_dataset must implement __len__")

    eval_dataloader = self.get_eval_dataloader(eval_dataset)

    output = self.prediction_loop(eval_dataloader, description="Evaluation")

    self.log(output.metrics)

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    # Return the metrics dict to match the signature and docstring above.
    return output.metrics
def evaluate(
    self,
    eval_datasets: Optional[list] = None,
) -> Dict[str, float]:
    """
    Run evaluation and return metrics.

    The calling script will be responsible for providing a method to compute metrics, as they
    are task-dependent (pass it to the init :obj:`compute_metrics` argument).

    You can also subclass and override this method to inject custom behavior.

    Args:
        eval_datasets (:obj:`list`, `optional`):
            Pass datasets if you wish to override :obj:`self.eval_dataset`.

    Returns:
        A dictionary containing the evaluation loss and the potential metrics computed from
        the predictions.
    """
    if eval_datasets is None:
        # Avoid a mutable default argument; a single None entry falls back
        # to self.eval_dataset inside get_eval_dataloader.
        eval_datasets = [None]

    for eval_dataset in eval_datasets:
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        output = self.prediction_loop(eval_dataloader, description="Evaluation")
        self.log(output.metrics)

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    return output.metrics
def _measure_speed(self, func) -> float:
    try:
        if self.args.is_tpu or self.args.torchscript:
            # run an additional 5 times to stabilize compilation for tpu and torchscript
            logger.info(
                "Do inference on TPU or torchscript. Running model 5 times to stabilize compilation"
            )
            timeit.repeat(
                func,
                repeat=1,
                number=5,
            )

        # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat,
        # min should be taken rather than the average
        runtimes = timeit.repeat(
            func,
            repeat=self.args.repeat,
            number=10,
        )

        if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
            import torch_xla.debug.metrics as met
            self.print_fn(met.metrics_report())

        return min(runtimes) / 10.0
    except RuntimeError as e:
        self.print_fn(f"Doesn't fit on GPU. {e}")
        return "N/A"
def print_aten_ops():
    "print out xla aten operations (from xla debug metrics report `torch_xla.debug.metrics`)"
    # import torch_xla.debug.metrics as met
    from io import StringIO
    import sys

    class Capturing(list):

        def __enter__(self):
            self._stdout = sys.stdout
            sys.stdout = self._stringio = StringIO()
            return self

        def __exit__(self, *args):
            self.extend(self._stringio.getvalue().splitlines())
            del self._stringio  # free up some memory
            sys.stdout = self._stdout

    out = met.metrics_report()
    # str.find returns -1 (truthy) when the substring is absent and 0 (falsy)
    # when it is at the start, so test membership instead.
    if "aten::" in out:
        print_now = False
        lines = out.split("\n")
        for l in lines:
            if print_now:
                print_now = False
                print(l)
            if l.find("aten::") > -1:
                print("needs lowering:", l)
                print_now = True
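print_aten_ops flags counters whose names start with aten::, i.e. ops that PyTorch/XLA executed via the CPU fallback rather than lowering to XLA. A shorter sketch of the same check, assuming the counter_names/counter_value helpers of torch_xla.debug.metrics are available:

# Minimal sketch: list aten:: counters straight from the metrics API instead
# of parsing the text report. Assumes torch_xla is installed.
import torch_xla.debug.metrics as met

for name in met.counter_names():
    if name.startswith('aten::'):
        print('needs lowering:', name, met.counter_value(name))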
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    eval_examples=None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
    **gen_kwargs,
) -> Dict[str, float]:
    gen_kwargs = gen_kwargs.copy()
    gen_kwargs["max_length"] = (gen_kwargs["max_length"] if gen_kwargs.get("max_length")
                                is not None else self.args.generation_max_length)
    gen_kwargs["num_beams"] = (gen_kwargs["num_beams"] if gen_kwargs.get("num_beams")
                               is not None else self.args.generation_num_beams)
    self._gen_kwargs = gen_kwargs

    eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    eval_examples = self.eval_examples if eval_examples is None else eval_examples

    # Temporarily disable metric computation, we will do it in the loop here.
    compute_metrics = self.compute_metrics
    self.compute_metrics = None
    eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
    try:
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )
    finally:
        self.compute_metrics = compute_metrics

    if self.post_process_function is not None and self.compute_metrics is not None:
        eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
        metrics = self.compute_metrics(eval_preds)

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        self.log(metrics)
    else:
        metrics = {}

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(
        self.args, self.state, self.control, metrics)
    return metrics
def save_metrics(metrics_file=None):
    if metrics_file is None:
        metrics_file = _STEP_METRICS_FILE
    if metrics_file is not None:
        metrics_data = met.metrics_report()
        if metrics_file == 'STDERR':
            print(metrics_data, file=sys.stderr)
        elif metrics_file == 'STDOUT':
            print(metrics_data)
        else:
            with _STEP_METRICS_FILE_LOCK:
                with open(metrics_file, 'a') as fd:
                    fd.write(metrics_data)
def save_metrics(metrics_file=None):
    if metrics_file is None:
        metrics_file = _get_metrics_file()
    if metrics_file is not None:
        metrics_data = '[MetricsData; step={}]\n{}\n'.format(
            _counter(), met.metrics_report())
        if metrics_file == 'STDERR':
            print(metrics_data, file=sys.stderr)
        elif metrics_file == 'STDOUT':
            print(metrics_data)
        else:
            with _STEP_METRICS_FILE_LOCK:
                with open(metrics_file, 'a') as fd:
                    fd.write(metrics_data)
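Both save_metrics variants route the report by a sentinel filename ('STDERR', 'STDOUT', or a real path). A hedged sketch of what a _get_metrics_file-style resolver might look like; the XLA_METRICS_FILE environment variable name here is an assumption for illustration, not necessarily the name the original module uses:

# Hypothetical sketch of a _get_metrics_file-style helper; the environment
# variable name is an assumption. Returning None disables metric saving.
import os

def _get_metrics_file():
    # 'STDERR' / 'STDOUT' are the sentinel values save_metrics understands;
    # anything else is treated as a file path to append to.
    return os.environ.get('XLA_METRICS_FILE', None)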
def test_parse_real_metrics(self):
    print('Testing against TPU. If this hangs, check that $XRT_TPU_CONFIG is set')
    x = torch.rand(3, 5, device=xm.xla_device())
    x = torch.flatten(x, 1)
    x = torch.roll(x, 1, 0)
    x = torch.flip(x, [0, 1])
    self.assertEqual(x.device.type, 'xla')
    metrics = met.metrics_report()
    self.assertTrue(metrics)
    data_points = mcu.get_data_points_from_metrics_reports([metrics])
    self.assertIn('CompileTime__Percentile_99_sec', data_points.keys())
    self.assertIn('CompileTime__TotalSamples', data_points.keys())
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None,
             metric_key_prefix: str = "eval"):
    eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    eval_examples = self.eval_examples if eval_examples is None else eval_examples

    # Temporarily disable metric computation, we will do it in the loop here.
    compute_metrics = self.compute_metrics
    self.compute_metrics = None
    try:
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )
    finally:
        self.compute_metrics = compute_metrics

    # We might have removed columns from the dataset so we put them back.
    if isinstance(eval_dataset, datasets.Dataset):
        eval_dataset.set_format(
            type=eval_dataset.format["type"],
            columns=list(eval_dataset.features.keys()),
        )

    if self.post_process_function is not None and self.compute_metrics is not None:
        eval_preds = self.post_process_function(
            eval_examples, eval_dataset, output.predictions, self.args
        )
        metrics = self.compute_metrics(eval_preds)
        self.log(metrics)
    else:
        metrics = {}

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(
        self.args, self.state, self.control, metrics
    )
    return metrics
def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None):
    eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
    eval_dataloader = self.get_eval_dataloader(eval_dataset)
    eval_examples = self.eval_examples if eval_examples is None else eval_examples
    process_start = time.time()

    # Temporarily disable metric computation, we will do it in the loop here.
    compute_metrics = self.compute_metrics
    self.compute_metrics = None
    try:
        output = self.prediction_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if compute_metrics is None else None,
            ignore_keys=ignore_keys,
        )
    finally:
        self.compute_metrics = compute_metrics

    # We might have removed columns from the dataset so we put them back.
    if isinstance(eval_dataset, datasets.Dataset):
        eval_dataset.set_format(type=eval_dataset.format["type"],
                                columns=list(eval_dataset.features.keys()))

    if self.post_process_function is not None and self.compute_metrics is not None:
        eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
        metrics = self.compute_metrics(eval_preds)
        # _compute(eval_preds.predictions, eval_preds.label_ids, self.rank_logger)
        process_end = time.time()
        self.rank_logger.info('total_accuracy:{}'.format(metrics['f1'] / 100))
        self.rank_logger.info('avg_ips:{}'.format(
            self.num_examples(eval_dataloader) /
            (process_end - process_start - self.warmup_duration)))
        self.rank_logger.info('test_end')
        self.log(metrics)
    else:
        metrics = {}

    if self.args.tpu_metrics_debug or self.args.debug:
        # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
        xm.master_print(met.metrics_report())

    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
    return metrics
def train_model_xla(net, batch_size, lr, num_epochs, log_steps=20, metrics_debug=False):
    torch.manual_seed(1)

    train_loader, test_loader = load_cifar_10_xla(batch_size)

    # Scale learning rate to num cores
    lr = lr * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), net, optimizer,
                      loss_fn, batch_size, log_steps)
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device), net)

        if metrics_debug:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target
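train_model_xla runs on the single XLA device of the calling process; on a multi-core TPU it would typically be launched once per core. A minimal launch sketch under that assumption, using torch_xla's xmp.spawn; the _mp_fn wrapper, make_net factory, and hyperparameters are illustrative, not from the original code:

# Hedged launch sketch: one process per TPU core via xmp.spawn.
# make_net and the flags values are assumptions for illustration.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    accuracy, _, _, _ = train_model_xla(
        net=make_net(),  # assumed model factory
        batch_size=flags['batch_size'],
        lr=flags['lr'],
        num_epochs=flags['num_epochs'])
    xm.master_print('Final accuracy: {:.2f}'.format(accuracy))

if __name__ == '__main__':
    flags = {'batch_size': 128, 'lr': 0.1, 'num_epochs': 1}
    xmp.spawn(_mp_fn, args=(flags,), nprocs=8, start_method='fork')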
def tpu_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()
    cnt = 0
    total = len(data_loader)
    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        loss = loss_dict['loss_giou']
        xm.master_print('Number: {}/{}, Loss:{}'.format(cnt + 1, total, loss.item()))
        cnt += 1
    xm.master_print(met.metrics_report())
    exit()
def run(self):
    bench_name = self._get_parent_class().__name__
    try:
        self.setup()
        # Do one warmup run.
        self.bench()
    except Exception as e:
        xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))
        return
    try:
        start = time.time()
        now = start
        count = 0
        while self.test_time > (now - start):
            self.bench()
            count += 1
            now = time.time()
        print('{}: {:.3f}ms per loop'.format(bench_name,
                                             1000.0 * (now - start) / count))
        xu.get_print_fn()(met.metrics_report())
    except Exception as e:
        xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path:
            (Optional) Local path to model if model to train has been instantiated from a local path.
            If present, we will try reloading the optimizer/scheduler states from there.
    """
    train_dataloader = self.get_train_dataloader()
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps *
                      self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (model_path is not None and
            os.path.isfile(os.path.join(model_path, "optimizer.pt")) and
            os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(model_path, "optimizer.pt"),
                       map_location=self.args.device))
        scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

    model = self.model
    if self.args.fp16:
        if not is_apex_available():
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if self.args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if self.args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank],
            output_device=self.args.local_rank,
            find_unused_parameters=True,
        )

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())
        self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

    # Train!
    if is_torch_tpu_available():
        total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
    else:
        total_train_batch_size = (self.args.train_batch_size *
                                  self.args.gradient_accumulation_steps *
                                  (torch.distributed.get_world_size()
                                   if self.args.local_rank != -1 else 1))
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", self.num_examples(train_dataloader))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per device = %d",
                self.args.per_device_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                total_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    self.global_step = 0
    self.epoch = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if model_path is not None:
        # set global_step to global_step of last saved checkpoint from model path
        try:
            self.global_step = int(model_path.split("-")[-1].split("/")[0])
            epochs_trained = self.global_step // (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = self.global_step % (
                len(train_dataloader) // self.args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", self.global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            self.global_step = 0
            logger.info("  Starting fine-tuning.")

    tr_loss = 0.0
    logging_loss = 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(num_train_epochs),
                            desc="Epoch",
                            disable=not self.is_local_master())
    for epoch in train_iterator:
        if isinstance(train_dataloader, DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch)

        if is_torch_tpu_available():
            parallel_loader = pl.ParallelLoader(
                train_dataloader, [self.args.device]).per_device_loader(self.args.device)
            epoch_iterator = tqdm(parallel_loader,
                                  desc="Iteration",
                                  disable=not self.is_local_master())
        else:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=not self.is_local_master())

        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            tr_loss += self._training_step(model, inputs, optimizer)

            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                if self.args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   self.args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                if is_torch_tpu_available():
                    xm.optimizer_step(optimizer)
                else:
                    optimizer.step()

                scheduler.step()
                model.zero_grad()
                self.global_step += 1
                self.epoch = epoch + (step + 1) / len(epoch_iterator)

                if (self.args.logging_steps > 0 and
                        self.global_step % self.args.logging_steps == 0) or (
                            self.global_step == 1 and self.args.logging_first_step):
                    logs: Dict[str, float] = {}
                    logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                    # backward compatibility for pytorch schedulers
                    logs["learning_rate"] = (scheduler.get_last_lr()[0]
                                             if version.parse(torch.__version__) >=
                                             version.parse("1.4") else scheduler.get_lr()[0])
                    logging_loss = tr_loss

                    self._log(logs)

                    if self.args.evaluate_during_training:
                        self.evaluate()

                if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                    # In all cases (even distributed/parallel), self.model is always a reference
                    # to the model we want to save.
                    if hasattr(model, "module"):
                        assert model.module is self.model
                    else:
                        assert model is self.model
                    # Save model checkpoint
                    output_dir = os.path.join(self.args.output_dir,
                                              f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                    self.save_model(output_dir)

                    if self.is_world_master():
                        self._rotate_checkpoints()

                    if is_torch_tpu_available():
                        xm.rendezvous("saving_optimizer_states")
                        xm.save(optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                        xm.save(scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))
                    elif self.is_world_master():
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))

            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
            train_iterator.close()
            break
        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

    if self.tb_writer:
        self.tb_writer.close()

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    return TrainOutput(self.global_step, tr_loss / self.global_step)
def train_mnist(flags, state_dict):
    if flags.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=60000 // flags.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(flags.batch_size, 1, 28, 28),
                  torch.zeros(flags.batch_size, dtype=torch.int64)),
            sample_count=10000 // flags.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        test_dataset = datasets.MNIST(
            os.path.join(flags.datadir, str(xm.get_ordinal())),
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=flags.batch_size,
            sampler=train_sampler,
            drop_last=flags.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=flags.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=flags.batch_size,
            drop_last=flags.drop_last,
            shuffle=False,
            num_workers=flags.num_workers)

    # Scale learning rate to num cores
    lr = flags.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST()
    model.load_state_dict(state_dict)
    model = model.to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(flags.logdir)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, writer),
                    run_async=FLAGS.async_closures)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, flags.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if flags.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler, test_sampler = None, None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker, epoch, writer))

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(test_utils.print_test_update,
                                    args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
        return accuracy

    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        train_loop_fn(train_device_loader, epoch)
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_device_loader, epoch)
        xm.master_print('Epoch {} test end {}, Accuracy={:.2f}'.format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(
            writer,
            epoch,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True)
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
def train(args, train_dataset, model, tokenizer, disable_logging=False):
    """Train the model."""
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    train_sampler = get_sampler(train_dataset)
    dataloader = DataLoader(train_dataset,
                            sampler=train_sampler,
                            batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(dataloader) //
                                                   args.gradient_accumulation_steps) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataloader) * args.train_batch_size)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per TPU core = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        (args.train_batch_size * args.gradient_accumulation_steps * xm.xrt_world_size()),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    loss = None
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        # tpu-comment: Get TPU parallel loader which sends data to TPU in background.
        train_dataloader = pl.ParallelLoader(dataloader,
                                             [args.device]).per_device_loader(args.device)
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              total=len(dataloader),
                              disable=disable_logging)
        for step, batch in enumerate(epoch_iterator):
            # Save model checkpoint.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                logger.info("Saving model checkpoint to %s", output_dir)
                if xm.is_master_ordinal():
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))

                # Barrier to wait for saving checkpoint.
                xm.rendezvous("mid_training_checkpoint")
                # model.save_pretrained needs to be called by all ordinals
                model.save_pretrained(output_dir)

            model.train()
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = (batch[2]
                                            if args.model_type in ["bert", "xlnet"] else None)
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics.
                    results = {}
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer,
                                           disable_logging=disable_logging)
                    loss_scalar = loss.item()
                    logger.info(
                        "global_step: {global_step}, lr: {lr:.6f}, loss: {loss:.3f}".format(
                            global_step=global_step,
                            lr=scheduler.get_lr()[0],
                            loss=loss_scalar))
                    if xm.is_master_ordinal():
                        # tpu-comment: All values must be in CPU and not on TPU device
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", loss_scalar, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if xm.is_master_ordinal():
        tb_writer.close()
    return global_step, loss.item()
def evaluate(args, model, tokenizer, prefix="", disable_logging=False): """Evaluate the model""" if xm.is_master_ordinal(): # Only master writes to Tensorboard tb_writer = SummaryWriter(args.tensorboard_logdir) # Loop to handle MNLI double evaluation (matched, mis-matched) eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,) eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,) results = {} for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) eval_sampler = get_sampler(eval_dataset) if not os.path.exists(eval_output_dir): os.makedirs(eval_output_dir) dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False) eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(dataloader) * args.eval_batch_size) logger.info(" Batch size = %d", args.eval_batch_size) eval_loss = 0.0 nb_eval_steps = 0 preds = None out_label_ids = None for batch in tqdm(eval_dataloader, desc="Evaluating", disable=disable_logging): model.eval() with torch.no_grad(): inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]} if args.model_type != "distilbert": # XLM, DistilBERT and RoBERTa don't use segment_ids inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None outputs = model(**inputs) batch_eval_loss, logits = outputs[:2] eval_loss += batch_eval_loss nb_eval_steps += 1 if preds is None: preds = logits.detach().cpu().numpy() out_label_ids = inputs["labels"].detach().cpu().numpy() else: preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) # tpu-comment: Get all predictions and labels from all worker shards of eval dataset preds = xm.mesh_reduce("eval_preds", preds, np.concatenate) out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate) eval_loss = eval_loss / nb_eval_steps if args.output_mode == "classification": preds = np.argmax(preds, axis=1) elif args.output_mode == "regression": preds = np.squeeze(preds) result = compute_metrics(eval_task, preds, out_label_ids) results.update(result) results["eval_loss"] = eval_loss.item() output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") if xm.is_master_ordinal(): with open(output_eval_file, "w") as writer: logger.info("***** Eval results {} *****".format(prefix)) for key in sorted(results.keys()): logger.info(" %s = %s", key, str(results[key])) writer.write("%s = %s\n" % (key, str(results[key]))) tb_writer.add_scalar(f"{eval_task}/{key}", results[key]) if args.metrics_debug: # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) if xm.is_master_ordinal(): tb_writer.close() return results
def train_mnist(): torch.manual_seed(1) """ tpu 를 쓴다하면 dataset 에 할 일 train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) """ def get_dataset(): norm = transforms.Normalize((0.1307,), (0.3081,)) train_dataset = datasets.MNIST( FLAGS['datadir'], train=True, download=True, transform=transforms.Compose( [transforms.ToTensor(), norm])) test_dataset = datasets.MNIST( FLAGS['datadir'], train=False, download=True, transform=transforms.Compose( [transforms.ToTensor(), norm])) return train_dataset, test_dataset # Using the serial executor avoids multiple processes to # download the same data. train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=FLAGS['batch_size'], sampler=train_sampler, num_workers=FLAGS['num_workers'], drop_last=True) test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=FLAGS['batch_size'], shuffle=False, num_workers=FLAGS['num_workers'], drop_last=True) # Scale learning rate to world size lr = FLAGS['learning_rate'] * xm.xrt_world_size() # Get loss function, optimizer, and model """ tpu 쓴다하면 device 가 device = xm.xla_device() model = xmp.MpModelWrapper(MNIST()).to(device) """ device = xm.xla_device() model = WRAPPED_MODEL.to(device) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum']) loss_fn = nn.NLLLoss() def train_loop_fn(loader): tracker = xm.RateTracker() model.train() for x, (data, target) in enumerate(loader): optimizer.zero_grad() output = model(data) loss = loss_fn(output, target) loss.backward() # tpu 쓴다하면 optimizer 에 xm.optimizer_step(optimizer) xm.optimizer_step(optimizer) tracker.add(FLAGS['batch_size']) if x % FLAGS['log_steps'] == 0: print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format( xm.get_ordinal(), x, loss.item(), tracker.rate(), tracker.global_rate(), time.asctime()), flush=True) def test_loop_fn(loader): total_samples = 0 correct = 0 model.eval() data, pred, target = None, None, None for data, target in loader: output = model(data) pred = output.max(1, keepdim=True)[1] correct += pred.eq(target.view_as(pred)).sum().item() total_samples += data.size()[0] accuracy = 100.0 * correct / total_samples print('[xla:{}] Accuracy={:.2f}%'.format( xm.get_ordinal(), accuracy), flush=True) return accuracy, data, pred, target # Train and eval loops accuracy = 0.0 data, pred, target = None, None, None for epoch in range(1, FLAGS['num_epochs'] + 1): para_loader = pl.ParallelLoader(train_loader, [device]) train_loop_fn(para_loader.per_device_loader(device)) xm.master_print("Finished training epoch {}".format(epoch)) para_loader = pl.ParallelLoader(test_loader, [device]) accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device)) if FLAGS['metrics_debug']: xm.master_print(met.metrics_report(), flush=True) return accuracy, data, pred, target
self.b = ForTest1(self, a)

xdata = {
    2: (11, ['a', 'b'], 17),
    'w': [12, 'q', 12.33],
    17.09: set(['a', 'b', 21]),
}
data = ForTest2(xdata)
wids = []

def convert(x):
    wids.append(id(x))
    return x

xu.for_each_instance_rewrite(data,
                             lambda x: isinstance(x, (int, str, float)),
                             convert)
self.assertEqual(len(wids), 11)


if __name__ == '__main__':
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(42)
    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
    if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
        print(met.metrics_report())
    sys.exit(0 if test.result.wasSuccessful() else 1)
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        test_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
def main(config_file='config/bert_config.json'):
    """Main method for training.

    Args:
        config_file: in config dir
    """
    global datasets
    # 0. Load config and mkdir
    with open(config_file) as fin:
        config = json.load(fin, object_hook=lambda d: SimpleNamespace(**d))
    get_path(os.path.join(config.model_path, config.experiment_name))
    get_path(config.log_path)
    if config.model_type in ['rnn', 'lr', 'cnn']:
        # build vocab for rnn
        build_vocab(file_in=config.all_train_file_path,
                    file_out=os.path.join(config.model_path, 'vocab.txt'))

    # 1. Load data
    data = Data(vocab_file=os.path.join(config.model_path, 'vocab.txt'),
                max_seq_len=config.max_seq_len,
                model_type=config.model_type,
                config=config)

    def load_dataset():
        datasets = data.load_train_and_valid_files(
            train_file=config.train_file_path,
            valid_file=config.valid_file_path)
        return datasets

    if config.serial_load:
        datasets = SERIAL_EXEC.run(load_dataset)
    else:
        datasets = load_dataset()
    train_set, valid_set_train, valid_set_valid = datasets

    if torch.cuda.is_available():
        device = torch.device('cuda')
        # device = torch.device('cpu')
        # torch.distributed.init_process_group(backend="nccl")
        # sampler_train = DistributedSampler(train_set)
        sampler_train = RandomSampler(train_set)
    else:
        device = torch.device('cpu')
        sampler_train = RandomSampler(train_set)

    # TPU
    device = xm.xla_device()
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    data_loader = {
        'train': DataLoader(train_set,
                            sampler=sampler_train,
                            batch_size=config.batch_size),
        'valid_train': DataLoader(valid_set_train,
                                  batch_size=config.batch_size,
                                  shuffle=False),
        'valid_valid': DataLoader(valid_set_valid,
                                  batch_size=config.batch_size,
                                  shuffle=False)
    }

    # 2. Build model
    # model = MODEL_MAP[config.model_type](config)
    model = WRAPPED_MODEL
    # load model states.
    # if config.trained_weight:
    #     model.load_state_dict(torch.load(config.trained_weight))
    model.to(device)
    if torch.cuda.is_available():
        model = model
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model, find_unused_parameters=True)

    # 3. Train
    trainer = Trainer(model=model,
                      data_loader=data_loader,
                      device=device,
                      config=config)
    # best_model_state_dict = trainer.train()
    if config.model_type == 'bert':
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.01
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.0
        }]
        optimizer = AdamW(optimizer_parameters,
                          lr=config.lr,
                          betas=(0.9, 0.999),
                          weight_decay=1e-8,
                          correct_bias=False)
    else:  # rnn
        optimizer = Adam(model.parameters(), lr=config.lr)

    # if config.model_type == 'bert':
    #     scheduler = get_linear_schedule_with_warmup(
    #         optimizer,
    #         num_warmup_steps=config.num_warmup_steps,
    #         num_training_steps=config.num_training_steps)
    # else:  # rnn
    #     scheduler = get_constant_schedule(optimizer)
    criterion = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            # batch = tuple(t.to(self.device) for t in batch)
            output = model(*batch[:-1])  # the last one is label
            loss = criterion(output, batch[-1])
            loss.backward()
            # xm.optimizer_step(optimizer)
            # optimizer.zero_grad()
            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                # Gradient accumulation: run several backward passes before the
                # optimizer updates the parameters (i.e. before optimizer.step()),
                # letting the gradients accumulate in parameter.grad, then update
                # the parameters once with the accumulated gradient.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()
            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                          .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                                  tracker.global_rate(), time.asctime()),
                          flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        tracker = xm.RateTracker()
        for x, batch in enumerate(loader):
            output = model(*batch[:-1])  # the last one is label
            target = batch[-1]
            # pred = output.max(1, keepdim=True)[1]
            # correct += pred.eq(target.view_as(pred)).sum().item()
            for i in range(len(output)):
                logits = output[i]
                pred = int(torch.argmax(logits, dim=-1))
                if pred == target[i]:
                    correct += 1
            total_samples += len(output)
            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print('[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                          .format(xm.get_ordinal(), x, correct * 1.0 / total_samples,
                                  tracker.rate(), tracker.global_rate(),
                                  time.asctime()),
                          flush=True)
        accuracy = 100.0 * correct / total_samples
        if xm.get_ordinal() == 0:
            print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy),
                  flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(FLAGS.num_epoch):
        para_loader = pl.ParallelLoader(data_loader['train'], [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        # para_loader = pl.ParallelLoader(data_loader['valid_train'], [device])
        # accuracy_train, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))

        para_loader = pl.ParallelLoader(data_loader['valid_valid'], [device])
        accuracy_valid, data, pred, target = test_loop_fn(
            para_loader.per_device_loader(device))
        xm.master_print("Finished test epoch {}, valid={:.2f}".format(
            epoch, accuracy_valid))
        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    # 4. Save model
    # if xm.get_ordinal() == 0:
    #     # if epoch == FLAGS.num_epoch - 1:
    #     #     WRAPPED_MODEL.to('cpu')
    #     torch.save(WRAPPED_MODEL.state_dict(), os.path.join(
    #         config.model_path, config.experiment_name,
    #         config.model_type + '-' + str(epoch + 1) + '.bin'))
    #     xm.master_print('saved model.')
    #     WRAPPED_MODEL.to(device)
    return accuracy_valid