def barrier(self) -> None:
    """
    Synchronize all processes.

    This is a collective call: each process blocks here until all
    processes have entered the function. It is implemented as an
    ``xm.rendezvous`` on the ``"barrier"`` tag.
    """
    rendezvous_tag = "barrier"
    xm.rendezvous(rendezvous_tag)
def on_batch_end(self, trainer, pl_module):
    """Update the loss display after a batch, then save / generate on schedule.

    Steps, in order:
      1. Delegate to the parent hook.
      2. Before the first counted step, empty the CUDA cache when ``self.gpu``.
      3. Read the current loss from ``trainer.progress_bar_dict["loss"]`` and
         update the smoothed average (skipped when the loss is NaN, in which
         case the displayed average is 0).
      4. Every ``progress_bar_refresh_rate`` steps, refresh the progress bar.
      5. When ``self.enabled``, save the model every ``save_every`` steps and,
         when not on TPU, generate sample text every ``generate_every`` steps.
    """
    super().on_batch_end(trainer, pl_module)

    # clean up the GPU cache used for the benchmark
    # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
    if self.steps == 0 and self.gpu:
        torch.cuda.empty_cache()

    loss_now = float(trainer.progress_bar_dict["loss"])
    self.steps += 1

    # NaN is the only float that differs from itself; it is kept out of the
    # running average.
    if loss_now != loss_now:
        smoothed = 0
    else:
        smoothed = self.average_loss(loss_now, self.prev_avg_loss, self.smoothing)
        self.prev_avg_loss = smoothed

    label = f"Loss: {loss_now:.3f} — Avg: {smoothed:.3f}"

    if self.steps % self.progress_bar_refresh_rate == 0:
        if self.gpu:
            first_gpu_mem = list(get_gpu_memory_map().values())[0]
            label += f" — GPU Mem: {first_gpu_mem} MB"
        self.main_progress_bar.update(self.progress_bar_refresh_rate)
        self.main_progress_bar.set_description(label)

    if not self.enabled:
        return

    save_due = self.save_every > 0 and self.steps % self.save_every == 0
    if save_due:
        if pl_module.hparams["tpu"]:
            # Line up all TPU processes before the model is written.
            xm.rendezvous("save_model")
        self.save_pytorch_model(trainer, pl_module)

    # Sample generation is skipped entirely on TPU.
    if pl_module.hparams["tpu"]:
        return
    if self.generate_every > 0 and self.steps % self.generate_every == 0:
        self.generate_sample_text(trainer, pl_module)
def save(data, path, master_only=True, global_master=False):
    """Write `data` to `path`, then synchronize all participants.

    The data is moved to the PyTorch CPU device before it is written, so a
    later `torch.load()` yields CPU data. Be careful with views: rather than
    saving views, recreate them once the tensors have been loaded and moved
    to their destination device(s).

    Args:
      data: The object to save. Any nested combination of Python objects
        (list, tuples, sets, dicts, ...).
      path: Destination file. When `master_only` is ``False`` every writer
        must be given a different path, otherwise writes coming from the
        same host overwrite each other.
      master_only (bool, optional): Whether only the master device writes the
        data. If False, `path` should differ for each ordinal taking part in
        the replication, otherwise all the replicas on the same host write
        to the same location.
        Default: True
      global_master (bool, optional): Only meaningful when ``master_only`` is
        ``True``. If ``False`` each host's master writes the content; if
        ``True`` only the global master (ordinal 0) does.
        Default: False
    """
    if master_only:
        should_write_data = xm.is_master_ordinal(local=not global_master)
    else:
        should_write_data = True

    tensors_folder = _get_tensors_folder(path)
    ref_data = _rewrite_data(tensors_folder, data, should_write_data)
    if should_write_data:
        torch.save(ref_data, path)

    # Every caller reaches this rendezvous, whether or not it wrote the file.
    xm.rendezvous('torch_xla.utils.serialization.save')
def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
    """Combine the per-process early-stopping flags into one decision.

    Each process contributes 1 (stop) or 0 (continue); the contributions are
    summed with ``xm.mesh_reduce``. The result is ``True`` only when the sum
    equals ``self.world_size``.

    Args:
        should_stop: This process's own stop flag.

    Returns:
        Whether the summed flags equal ``self.world_size``.
    """
    local_vote = torch.tensor(int(should_stop), device=self.lightning_module.device)
    vote_total = xm.mesh_reduce('stop_signal', local_vote, sum)
    rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
    return int(vote_total.item()) == self.world_size
def distributed_init(args):
    """Initialize distributed training and return this process's rank.

    On the non-TPU path a ``torch.distributed`` process group is created
    (unless one already exists, which only triggers a warning) and the rank
    is read back from ``torch.distributed``. On the TPU path the rank and
    device id come from ``torch_xla`` and all workers rendezvous.

    Afterwards the root logger level is set to INFO on the master and WARNING
    elsewhere, and, when ``args.model_parallel_size > 1``, Megatron model
    parallelism is initialized and ``args.checkpoint_suffix`` is extended.

    Args:
        args: Namespace of options; ``distributed_rank``, ``device_id`` (TPU
            only) and ``checkpoint_suffix`` may be updated in place.

    Returns:
        ``args.distributed_rank`` after initialization.

    Raises:
        ValueError: If ``args.distributed_world_size`` is 1.
        ImportError: If model parallelism is requested but the Megatron
            submodule cannot be imported.
    """
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if getattr(args, 'tpu', False):
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()
    else:
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank, args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(), args.distributed_rank,
            ))
            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())
        args.distributed_rank = torch.distributed.get_rank()

    # Only the master logs at INFO level.
    root_level = logging.INFO if is_master(args) else logging.WARNING
    logging.getLogger().setLevel(root_level)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(part_number)

    return args.distributed_rank
def valid_step(self, sample, raise_oom=False):
    """Do forward pass in evaluation mode.

    Args:
        sample: The batch to validate. If preparing it yields ``None``, the
            stored dummy batch is used instead and its sample size is zeroed.
        raise_oom: When False, an out-of-memory ``RuntimeError`` triggers one
            retry (with ``raise_oom=True``) after freeing gradients and, on
            CUDA, the cache. When True, the error is re-raised.

    Returns:
        The logging output produced by ``self._reduce_and_log_stats``.
    """
    if self.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("valid_step")  # wait for all workers
        xm.mark_step()

    with torch.no_grad():
        self.model.eval()
        self.criterion.eval()

        sample = self._prepare_sample(sample)
        is_dummy_batch = sample is None
        if is_dummy_batch:
            sample = self._prepare_sample(self._dummy_batch)
        elif self._dummy_batch == "DUMMY":
            # Keep the first real batch so it can stand in for empty ones.
            self._dummy_batch = sample

        try:
            _loss, sample_size, logging_output = self.task.valid_step(
                sample, self.model, self.criterion)
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if not raise_oom:
                    logger.warning(
                        "ran out of memory in validation step, retrying batch"
                    )
                    for param in self.model.parameters():
                        if param.grad is not None:
                            param.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
            raise e

        logging_outputs = [logging_output]
        if is_dummy_batch:
            # A dummy batch must not count towards the sample size.
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

    # gather logging outputs from all replicas
    if self.data_parallel_world_size > 1:
        logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ignore=is_dummy_batch,
        )

    # log validation stats
    return self._reduce_and_log_stats(logging_outputs, sample_size)
def run_validation(e: Engine):
    """Run the evaluator over the validation data.

    When ``distributed`` is set, CUDA is synchronized on ``device`` first.
    On TPU, all processes rendezvous on a tag derived from the engine's
    current iteration and the evaluator consumes a per-device loader;
    otherwise the evaluator runs directly on ``valid_dl``.
    """
    if distributed:
        torch.cuda.synchronize(device)

    if not use_tpu:
        evaluator.run(valid_dl)
        return

    xm.rendezvous('validate_{}'.format(e.state.iteration))
    per_device_batches = valid_dl.per_device_loader(device)
    evaluator.run(per_device_batches, epoch_length=len(valid_dl))
def tpu_data_loader(args, itr):
    """Wrap `itr` in a TPU parallel loader behind a counting iterator.

    All workers rendezvous and mark a step before the loader is built. The
    returned ``CountingIterator`` starts at ``itr.n`` (0 if absent) and has
    ``len(itr)`` as its total.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()

    device = utils.get_tpu_device(args)
    per_device_itr = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    already_consumed = getattr(itr, 'n', 0)
    return iterators.CountingIterator(
        per_device_itr,
        start=already_consumed,
        total=len(itr),
    )
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args, ctrl_args):
    "Run `fit_method` on a child learner in a spawned process, then save the model to `_xla_tmp_model`"
    # The child learner is always built with validation syncing enabled.
    learner = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)
    fit_kwargs = setup_fit_cbs(rank, fit_args)
    fit_method(learner, **fit_kwargs)

    # Rendezvous explicitly here, so the save itself is told not to.
    xm.rendezvous('xla_run_method')
    learner.save('_xla_tmp_model', rendezvous=False)
    xm.mark_step()
def wait_for_everyone():
    """
    Block until every process has reached this call.

    Uses ``torch.distributed.barrier()`` for the multi-GPU distributed type
    and an XLA rendezvous for the TPU distributed type; for any other
    distributed type the function returns immediately.

    Warning::

        Make sure all processes will reach this instruction otherwise one of
        your processes will hang forever.
    """
    if AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
        torch.distributed.barrier()
        return
    if AcceleratorState().distributed_type == DistributedType.TPU:
        xm.rendezvous("accelerate.utils.wait_for_everyone")
def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, split_D=False):
    """Run the generator, then the discriminator on its output (and real data).

    Args:
        z: Noise passed to the generator.
        gy: Labels for the generated samples; embedded with ``self.G.shared``.
        x: Optional real data.
        dy: Optional labels for the real data.
        train_G: Whether gradients are enabled for the generator pass.
        return_G_z: When no real data is given, also return the generator
            output.
        split_D: Run D separately on fake and real data instead of once on
            their concatenation along the batch dimension.

    Returns:
        ``(D_fake, D_real)`` when ``x`` is given; otherwise the discriminator
        output, paired with ``G_z`` when ``return_G_z`` is set.
    """
    # If training G, enable grad tape
    with torch.set_grad_enabled(train_G):
        # Get Generator output given noise
        G_z = self.G(z, self.G.shared(gy))
        # Cast as necessary so that G's output matches D's precision
        gen_fp16, disc_fp16 = self.G.fp16, self.D.fp16
        if gen_fp16 and not disc_fp16:
            G_z = G_z.float()
        if disc_fp16 and not gen_fp16:
            G_z = G_z.half()

    has_real = x is not None

    if split_D:
        fake_out = self.D(G_z, gy)
        if has_real:
            real_out = self.D(x, dy)
            return fake_out, real_out
        return (fake_out, G_z) if return_G_z else fake_out

    # If real data is provided, concatenate it with the Generator's output
    # along the batch dimension for improved efficiency.
    D_input = torch.cat([G_z, x], 0) if has_real else G_z

    self.counter += 1
    # Every 50th call the master dumps the first 64 discriminator inputs.
    if xm.is_master_ordinal() and (self.counter % 50 == 0):
        grid_width = int(D_input.shape[0]**0.5)
        torchvision.utils.save_image(D_input[:64].float().cpu(), 'd_inp.png',
                                     nrow=grid_width, normalize=True)
    # NOTE(review): the source layout was flattened; this rendezvous is
    # assumed to run on every replica, not only on the master -- confirm.
    xm.rendezvous('cont....')

    D_class = torch.cat([gy, dy], 0) if dy is not None else gy
    # Get Discriminator output
    D_out = self.D(D_input, D_class)
    if has_real:
        # D_fake, D_real
        return torch.split(D_out, [G_z.shape[0], x.shape[0]])
    return (D_out, G_z) if return_G_z else D_out
def tpu_data_loader(itr):
    """Wrap `itr` in a TPU parallel loader behind a counting iterator.

    All workers rendezvous and mark a step before the loader is built on the
    current XLA device. The returned ``CountingIterator`` starts at ``itr.n``
    (0 if absent) and has ``len(itr)`` as its total.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()

    device = xm.xla_device()
    per_device_itr = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    already_consumed = getattr(itr, "n", 0)
    return iterators.CountingIterator(
        per_device_itr,
        start=already_consumed,
        total=len(itr),
    )
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Return a ``TensorDataset`` of features for `task`, using a disk cache.

    Non-master ordinals wait at a rendezvous before doing anything; the
    master joins that rendezvous only after it has loaded or written the
    cache file.

    Args:
        args: Options; reads ``cache_dir``, ``model_name_or_path``,
            ``max_seq_length``, ``overwrite_cache``, ``data_dir`` and
            ``model_type``.
        task: Key into ``processors`` and ``output_modes``.
        tokenizer: Passed to ``convert_examples_to_features``.
        evaluate: Use the dev split instead of the train split.

    Returns:
        ``TensorDataset(input_ids, attention_mask, token_type_ids, labels)``.
    """
    if not xm.is_master_ordinal():
        xm.rendezvous("load_and_cache_examples")

    processor = processors[task]()
    output_mode = output_modes[task]

    split_name = "dev" if evaluate else "train"
    # Last non-empty path component of the model name/path.
    model_tag = list(filter(None, args.model_name_or_path.split("/"))).pop()
    cache_name = "cached_{}_{}_{}_{}".format(
        split_name, model_tag, str(args.max_seq_length), str(task),
    )
    cached_features_file = os.path.join(args.cache_dir, cache_name)

    # Load data features from cache or dataset file
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        if evaluate:
            examples = processor.get_dev_examples(args.data_dir)
        else:
            examples = processor.get_train_examples(args.data_dir)
        features = convert_examples_to_features(
            examples,
            tokenizer,
            max_length=args.max_seq_length,
            label_list=label_list,
            output_mode=output_mode,
        )
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    if xm.is_master_ordinal():
        xm.rendezvous("load_and_cache_examples")

    # Convert to Tensors and build dataset
    def _column(attr, dtype):
        return torch.tensor([getattr(f, attr) for f in features], dtype=dtype)

    all_input_ids = _column("input_ids", torch.long)
    all_attention_mask = _column("attention_mask", torch.long)
    all_token_type_ids = _column("token_type_ids", torch.long)
    if output_mode == "classification":
        all_labels = _column("label", torch.long)
    elif output_mode == "regression":
        all_labels = _column("label", torch.float)

    return TensorDataset(all_input_ids, all_attention_mask,
                         all_token_type_ids, all_labels)
def _save_tpu(self, output_dir: Optional[str] = None):
    """Save the training arguments and the model from a TPU process.

    The master ordinal creates the directory and writes
    ``training_args.bin``; then every process rendezvous and calls
    ``save_pretrained`` on the model.

    Args:
        output_dir: Destination directory; defaults to
            ``self.args.output_dir``.

    Raises:
        ValueError: If ``self.model`` is not a ``PreTrainedModel``.
    """
    if output_dir is None:
        output_dir = self.args.output_dir
    logger.info("Saving model checkpoint to %s", output_dir)

    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        args_path = os.path.join(output_dir, "training_args.bin")
        torch.save(self.args, args_path)

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if not isinstance(self.model, PreTrainedModel):
        raise ValueError("Trainer.model appears to not be a PreTrainedModel")

    xm.rendezvous("saving_checkpoint")
    self.model.save_pretrained(output_dir)
def begin_epoch(self, epoch):
    """Called at the beginning of each epoch.

    Notifies the quantizer (if any) and the task, then, on TPU, waits for
    all workers and marks a step.
    """
    logger.info("begin training epoch {}".format(epoch))

    quantizer = self.quantizer
    if quantizer is not None:
        quantizer.begin_epoch(epoch)

    # task specific setup per epoch
    self.task.begin_epoch(epoch, self.get_model())

    if not self.tpu:
        return

    import torch_xla.core.xla_model as xm
    xm.rendezvous('begin_epoch')  # wait for all workers
    xm.mark_step()
def synchronize(message="sync-workers"):
    """Synchronize workers.

    On XLA, rendezvous on `message` and then continue with the checks below.
    Off XLA, return immediately if ``torch.distributed`` is unavailable.
    A ``dist.barrier()`` is then issued only when NCCL is available, the
    process group is initialized, and the world size is not 1.

    Args:
        message: Tag used for the XLA rendezvous.
    """
    if is_xla():
        xm.rendezvous(message)
    else:
        if not dist.is_available():
            return

    # NOTE: the XLA path deliberately falls through to these checks, as in
    # the original control flow.
    if not (dist.is_nccl_available() and dist.is_initialized()):
        return

    if dist.get_world_size() == 1:
        return

    dist.barrier()
def _mp_fn(index):
    """Check that a replica-group rendezvous returns payloads in replica order.

    Each core sends ``ORD=<ordinal>`` (UTF-8 bytes) to the rendezvous of its
    replica group and asserts that the i-th returned payload carries the
    ordinal ``replicas[i]``. A final rendezvous runs before exit.
    """
    ordinal = xm.get_ordinal()
    print('Core {} waiting for rendezvous ...'.format(ordinal))

    replicas, gid = _get_replica_group(index)
    tag = 'rendezvous_test.{}'.format(gid)
    payload = 'ORD={}'.format(ordinal).encode('utf-8')
    data = xm.rendezvous(tag, payload, replicas=replicas)
    print('Core {} got rendezvous!'.format(ordinal))

    for i, raw in enumerate(data):
        idata = raw.decode('utf-8')
        m = re.match(r'ORD=(\d+)', idata)
        assert m, 'Bad payload format: {}'.format(idata)
        xordinal = int(m.group(1))
        assert replicas[i] == xordinal, 'Payload {} got ordinal {}'.format(
            replicas[i], xordinal)

    xm.rendezvous('_mp_fn.exit')
def simple_map_fn(index, flags):
    """Print device information for this process, then wait for all processes.

    Args:
        index: Index of this process.
        flags: Unused here.
    """
    # Sets a common random seed - both for initialization and ensuring graph is the same
    torch.manual_seed(1234)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()
    device_name = str(device)

    # Creates a tensor on this process's device
    t = torch.randn((2, 2), device=device)

    print("device: ", device, " str: ", device_name, " real: ",
          xm.xla_real_devices([device_name]))
    print("master: ", xm.is_master_ordinal())
    print("Process", index, "is using", xm.xla_real_devices([device_name])[0])

    # Barrier to prevent master from exiting before workers connect.
    xm.rendezvous('init')
def load_cifar_10_xla(batch_size, root='/tmp/cifar10'):
    """Download and load the CIFAR-10 dataset.

    Non-master processes wait at a rendezvous until the master has
    constructed (and thereby downloaded) the datasets.

    Args:
        batch_size: Batch size for both loaders.
        root: Directory the dataset is stored in.

    Returns:
        ``(train_loader, test_loader)``; the train loader uses a
        ``DistributedSampler`` over the XLA world, the test loader does not
        shuffle.
    """
    if not xm.is_master_ordinal():
        # Barrier: Wait until master is done downloading
        xm.rendezvous('download_only_once')

    # Get and shard dataset into dataloaders
    norm = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                std=(0.2023, 0.1994, 0.2010))
    augment_and_norm = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    norm_only = transforms.Compose([transforms.ToTensor(), norm])

    cifar10_train = datasets.CIFAR10(root=root, train=True, download=True,
                                     transform=augment_and_norm)
    cifar10_test = datasets.CIFAR10(root=root, train=False, download=True,
                                    transform=norm_only)

    if xm.is_master_ordinal():
        # Barrier: Master done downloading, other workers can proceed
        xm.rendezvous('download_only_once')

    train_sampler = DistributedSampler(cifar10_train,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)
    train_loader = DataLoader(cifar10_train, batch_size=batch_size,
                              sampler=train_sampler, num_workers=4,
                              drop_last=True)
    test_loader = DataLoader(cifar10_test, batch_size=batch_size,
                             shuffle=False, num_workers=4, drop_last=True)
    return train_loader, test_loader
def broadcast_xla_master_model_param(model):
    """Make every XLA process hold the main process's parameters and buffers.

    Non-main processes multiply their tensors by zero in place, after which a
    sum all-reduce over all parameters and buffers leaves every process with
    the main process's values. Ends with a rendezvous.

    Args:
        model: Module whose parameters and buffers are updated in place.
    """
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ..."
    )

    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        tensor = p.data
        # Set all params in non-master devices to zero so that all_reduce is equivalent
        # to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=tensor.dtype, device=tensor.device)
            tensor.mul_(zero)
        parameters_and_buffers.append(tensor)

    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
def xla_run_lr_find(rank, learner_args, add_args, lr_find_args, ctrl_args):
    "Run the LR finder on a child learner in a spawned process"
    xm.rendezvous('start_xla_run_lr_find')

    # The child learner is always built with validation syncing enabled.
    learner = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)

    # Enough epochs to cover `num_it` iterations of the training loader.
    num_it = lr_find_args['num_it']
    batches_per_epoch = len(learner.dls.train)
    n_epoch = num_it // batches_per_epoch + 1

    # Start from a freshly created optimizer.
    learner.opt = None
    learner.create_opt()

    callbacks = [
        XLALRFinder(**lr_find_args),
        SkipValidationCallback(),
        SyncedCancelCallback(),
    ]
    with learner.no_logging():
        learner.fit(n_epoch, cbs=callbacks)
def on_batch_end(self, trainer, pl_module):
    """Update the loss display after a batch, log the loss, then save / generate.

    Steps, in order:
      1. Delegate to the parent hook.
      2. Before the first counted step, empty the CUDA cache when ``self.gpu``.
      3. Read the current loss from ``trainer.progress_bar_dict["loss"]`` and
         update the smoothed average (skipped when the loss is NaN, in which
         case the displayed average is 0).
      4. On step 1 and on every save step, append the step number and the
         loss description to ``./loss.txt``.
      5. Every ``progress_bar_refresh_rate`` steps, refresh the progress bar.
      6. When ``self.enabled``, save the model every ``save_every`` steps and,
         when not on TPU, generate sample text every ``generate_every`` steps.

    A ``save_every`` of 0 (or less) disables the periodic save and the
    periodic loss logging; the step-1 loss line is still written.
    """
    super().on_batch_end(trainer, pl_module)

    # clean up the GPU cache used for the benchmark
    # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
    if self.steps == 0 and self.gpu:
        torch.cuda.empty_cache()

    current_loss = float(trainer.progress_bar_dict["loss"])
    self.steps += 1

    avg_loss = 0
    if current_loss == current_loss:  # don't add if current_loss is NaN
        avg_loss = self.average_loss(current_loss, self.prev_avg_loss,
                                     self.smoothing)
        self.prev_avg_loss = avg_loss

    desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}"

    # Guard the modulo: with save_every == 0 the unguarded
    # `self.steps % self.save_every` raises ZeroDivisionError.
    save_due = self.save_every > 0 and self.steps % self.save_every == 0

    # Added by serge: save the loss
    if save_due or self.steps == 1:
        print("Enregistrement du loss dans loss.txt")
        # `desc` contains an em dash, so the encoding is pinned rather than
        # left to the platform default. The `with` block closes the file.
        with open("./loss.txt", "a", encoding="utf-8") as fd:
            fd.write(f"{self.steps} {desc}\n")

    if self.steps % self.progress_bar_refresh_rate == 0:
        if self.gpu:
            desc += f" — GPU Mem: {get_gpu_memory_map()['gpu_0']} MB"
        self.main_progress_bar.update(self.progress_bar_refresh_rate)
        self.main_progress_bar.set_description(desc)

    if self.enabled:
        if save_due:
            if pl_module.hparams["tpu"]:
                # Line up all TPU processes before the model is written.
                xm.rendezvous("save_model")
            self.save_pytorch_model(trainer, pl_module)

        if (not pl_module.hparams["tpu"] and self.generate_every > 0
                and self.steps % self.generate_every == 0):
            self.generate_sample_text(trainer, pl_module)
def _mp_fn(index):
    """Check that a rendezvous returns payloads ordered by ordinal.

    Each core sends ``ORD=<ordinal>`` to the ``rendezvous_test`` rendezvous
    and asserts that the i-th returned payload carries ordinal ``i``.
    """
    ordinal = xm.get_ordinal()
    print('Core {} waiting for rendezvous ...'.format(ordinal))

    payload = 'ORD={}'.format(ordinal)
    data = xm.rendezvous('rendezvous_test', payload)
    print('Core {} got rendezvous!'.format(ordinal))

    for i, entry in enumerate(data):
        m = re.match(r'ORD=(\d+)', entry)
        assert m, 'Bad payload format: {}'.format(entry)
        xordinal = int(m.group(1))
        assert i == xordinal, 'Payload {} got ordinal {}'.format(i, xordinal)
def _mp_fn(index, _hps, _dat_file):
    """Per-process entry point: acquire the device, build a Chassis and train.

    Timing for device acquisition and Chassis construction is printed to
    stderr. All processes rendezvous on ``'init'`` before training starts.

    Args:
        index: Index of this process.
        _hps: Hyper-parameters; ``random_seed`` is read here and the object
            is forwarded to ``ch.Chassis``.
        _dat_file: Forwarded to ``ch.Chassis``.
    """
    t.manual_seed(_hps.random_seed)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    acquire_start = time.time()
    device = xm.xla_device()
    device_str = xm.xla_real_devices([str(device)])[0]
    elapsed = time.time() - acquire_start
    print(f'process {index} acquired {device_str} in {elapsed} seconds',
          file=stderr, flush=True)

    build_start = time.time()
    m = ch.Chassis(device, index, _hps, _dat_file)
    build_secs = time.time() - build_start
    print(f'Created Chassis in {build_secs:3.5} seconds.', file=stderr, flush=True)

    xm.rendezvous('init')
    m.train()
def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    "Run prediction on a child learner in a spawned process and save the results"
    # The child learner is always built with validation syncing enabled.
    learner = make_xla_child_learner(rank, True, learner_args, add_args, ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)

    if rank == 0:
        # Master-only callbacks and logger wiring.
        if len(master_cbs) > 0:
            learner.add_cbs(master_cbs)
        learner.sync_recorder.orig_logger = learner.logger

    results = learner.inner_get_preds(**pred_args)

    xm.rendezvous('xla_run_inference')
    save_pred_results(rank, results)
    xm.mark_step()
def sync_bn1d_no_channel_test(index):
    """Compare XLA SyncBatchNorm on a shard with CPU BatchNorm1d on the full batch.

    A global ``(world_size * 32, 64)`` tensor is generated from a fixed
    seed; this process runs its own 32-row slice through ``SyncBatchNorm``
    on the XLA device, and the output and statistics are checked against
    ``BatchNorm1d`` run on the whole tensor on CPU.

    Args:
        index: Selects which 32-row slice this process uses.
    """
    torch.manual_seed(1)
    bsz, length = 32, 64
    t_global = torch.rand((xm.xrt_world_size() * bsz, length))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    shard_start = bsz * index
    t_xla = t_global[shard_start:shard_start + bsz, ...].to(device)
    sbn_xla = xf.SyncBatchNorm(length).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm1d(length)
    expected = run_step(bn_cpu, t_global)

    result_on_cpu = result.cpu()
    assert result_on_cpu.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn1d_no_channel_test')
    xm.master_print('sync_bn1d_no_channel_test ok')
def sync_bn3d_test(index):
    """Compare XLA SyncBatchNorm on a shard with CPU BatchNorm3d on the full batch.

    A global ``(world_size * 16, 32, 16, 32, 32)`` tensor is generated from
    a fixed seed; this process runs its own 16-sample slice through
    ``SyncBatchNorm`` on the XLA device, and the output and statistics are
    checked against ``BatchNorm3d`` run on the whole tensor on CPU.

    Args:
        index: Selects which 16-sample slice this process uses.
    """
    torch.manual_seed(1)
    bsz, features = 16, 32
    d, h, w = 16, 32, 32
    t_global = torch.rand((xm.xrt_world_size() * bsz, features, d, h, w))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    shard_start = bsz * index
    t_xla = t_global[shard_start:shard_start + bsz, ...].to(device)
    sbn_xla = xf.SyncBatchNorm(features).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm3d(features)
    expected = run_step(bn_cpu, t_global)

    result_on_cpu = result.cpu()
    assert result_on_cpu.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn3d_test')
    xm.master_print('sync_bn3d_test ok')
def save_xla_ckpt(ckpt, file_or_path):
    """
    Similar to xm.save, but only try to convert "model" and "optimizer" in an
    MMF checkpoint to CPU, since they hold PyTorch tensors. Other items like
    lr_scheduler often cannot be saved with xm.save due to its errors in
    handling mappingproxy.

    Only save on the global main process (which is different from the default
    behavior of xm.save that saves a checkpoint on each node).

    Note that for a full checkpoint (a dict holding both "model" and
    "optimizer") those two entries of the passed dict are replaced in place.
    Every process ends at a rendezvous, whether or not it wrote the file.
    """
    should_write_data = is_main()

    is_full_ckpt = isinstance(ckpt, dict) and "model" in ckpt and "optimizer" in ckpt
    if is_full_ckpt:
        for key in ("model", "optimizer"):
            ckpt[key] = xm._maybe_convert_to_cpu(ckpt[key], convert=should_write_data)
    else:
        ckpt = xm._maybe_convert_to_cpu(ckpt, convert=should_write_data)

    if should_write_data:
        torch.save(ckpt, file_or_path)

    xm.rendezvous("mmf.utils.checkpoint.save_xla_ckpt")
def _mp_fn(index, temp_file):
    """Check that what ``xm.save`` wrote matches across processes.

    Each process saves a state dict to `temp_file`, loads it back, and
    exchanges a serialized form through a rendezvous. The local copy is then
    compared with every received payload except the first.

    Args:
        index: Index of this process (unused here).
        temp_file: File the state dict is saved to; removed by local ordinal 0.

    Raises:
        RuntimeError: If a loaded value is neither a tensor nor a list/tuple.
    """
    device = xm.xla_device()
    dd = _create_state_dict(device)
    xm.save(dd, temp_file)

    ldd = torch.load(temp_file)
    pdd = _get_data_str(ldd)
    data = xm.rendezvous('xm_save_test', pdd)

    if xm.get_local_ordinal() == 0:
        os.remove(temp_file)

    # Payload 0 is skipped, as in the original index range.
    for serialized in data[1:]:
        ildd = torch.load(io.BytesIO(serialized))
        for k, v in ldd.items():
            if isinstance(v, torch.Tensor):
                assert v.allclose(ildd[k])
            elif isinstance(v, (list, tuple)):
                for a, b in zip(v, ildd[k]):
                    assert a.allclose(b)
            else:
                raise RuntimeError('Invalid data type')
def on_epoch_end(self, callbacks):
    """Run the end-of-epoch callbacks and reset the train/dev metrics.

    Without TPU the work runs directly. With TPU all processes rendezvous,
    only the master ordinal does the work, and all processes rendezvous
    again afterwards.

    Args:
        callbacks: Objects exposing ``on_epoch_end()``.
    """
    def execute():
        # NOTE(review): the metric resets are assumed to belong to this
        # helper (the source layout was flattened) -- confirm.
        for callback in callbacks:
            callback.on_epoch_end()
        # reset
        for metric in self.train_eval.values():
            metric.reset()
        for metric in self.dev_eval.values():
            metric.reset()

    if not self.using_tpu:
        execute()
        return

    xm.rendezvous("train is done!")
    if xm.is_master_ordinal():
        execute()
    # Master and non-master ordinals both wait here.
    xm.rendezvous("on_epoch_end is done!")