def _mp_fn(index):
    ordinal = xm.get_ordinal()
    print('Core {} waiting for rendezvous ...'.format(ordinal))
    # rendezvous lives in torch_xla.core.xla_model (xm), not xla_multiprocessing
    data = xm.rendezvous('rendezvous_test', 'ORD={}'.format(ordinal))
    print('Core {} got rendezvous!'.format(ordinal))
    for i in range(0, len(data)):
        m = re.match(r'ORD=(\d+)', data[i])
        assert m, 'Bad payload format: {}'.format(data[i])
        xordinal = int(m.group(1))
        assert i == xordinal, 'Payload {} got ordinal {}'.format(i, xordinal)
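# Launch sketch (an assumption, not part of the original snippet): _mp_fn above is the kind of
# per-process entry point handed to torch_xla's multiprocessing spawner, which starts one process
# per TPU core and passes each process its index. The snippet itself assumes `import re` and
# `import torch_xla.core.xla_model as xm` at module level.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # nprocs=None lets torch_xla use all available devices.
    xmp.spawn(_mp_fn, args=(), nprocs=None, start_method='fork')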
def is_main_process(local_rank):
    """
    Whether or not the current process is the main process, determined from `xm.get_ordinal()`
    on TPUs and from `local_rank` otherwise.
    """
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        return xm.get_ordinal() == 0
    return local_rank in [-1, 0]
def get_rank():
    if is_xla():
        return xm.get_ordinal()
    if not dist.is_available():
        return 0
    if not dist.is_nccl_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
def distributed_init(args):
    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn('Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank, args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(), args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)

    return args.distributed_rank
def make_loader(self, dataset, batch_size, phase):
    shuffle = phase == "train"
    if self.using_tpu:
        sampler = DistributedSampler(dataset=dataset,
                                     num_replicas=xm.xrt_world_size(),
                                     rank=xm.get_ordinal(),
                                     shuffle=shuffle)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            sampler=sampler,
                            num_workers=xm.xrt_world_size(),
                            drop_last=True)
    else:
        loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    return loader
def build_dataloader(args, tokenizer):
    train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
    train_sampler = RandomSampler(train_dataset)
    if xm.xrt_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=xm.xrt_world_size(),
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    return DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
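# Usage sketch (an assumption): inside the per-core training function, the loader returned by
# build_dataloader is typically wrapped in a ParallelLoader so batches are prefetched onto the
# local XLA device. `args`, `tokenizer`, `model`, and `optimizer` are assumed to come from the
# surrounding script; the transformers-style `loss = model(**batch)[0]` call is illustrative.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _train_one_epoch_sketch(args, tokenizer, model, optimizer):
    device = xm.xla_device()
    loader = build_dataloader(args, tokenizer)
    per_device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for step, batch in enumerate(per_device_loader):
        loss = model(**batch)[0]
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduce gradients across cores, then step
        optimizer.zero_grad()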
def _wrap_dl(self, dl):
    if isinstance(dl, pl.PerDeviceLoader):
        return dl
    else:
        # dl = dl.to(self.device)
        # For some reason num_workers=0 is needed for this to work (something on fastai's end);
        # needs further investigation.
        dl.fake_l.num_workers = 0
        # Use existing distributed functionality to shard the dataloader across cores.
        distributed_dl = TPUDistributedDL(dl, xm.get_ordinal(), xm.xrt_world_size())
        return pl.ParallelLoader(distributed_dl, [self.device]).per_device_loader(self.device)
def process_index(self):
    """
    The index of the current process used.
    """
    if is_torch_tpu_available():
        return xm.get_ordinal()
    elif is_sagemaker_distributed_available():
        return sm_dist.get_rank()
    elif self.local_rank != -1:
        return torch.distributed.get_rank()
    return 0
def _train_xla(self, rank, loader, loader_valid, num_epochs):
    seed_everything(self.random_state, self.deterministic)
    self.device = xm.xla_device()
    self.rank = xm.get_ordinal()
    self._configure_model()
    loader = self._configure_loader_ddp(loader)
    loader_valid = self._configure_loader_ddp(loader_valid, shuffle=False)
    self._train(
        loader.per_device_loader(self.device),
        loader_valid.per_device_loader(self.device),
        num_epochs)
def test_loop_fn(loader):
    model.eval()
    for x, (data, label) in enumerate(loader):
        output = model(image=data, label=label, get_embedding=args.get_embeddings)
        loss = loss_fn(output, label)
        if x % 20 == 0:
            print('[xla:{}]({}) Loss={:.5f}'.format(
                xm.get_ordinal(), x, loss.item()), flush=True)
def auto_add_sampler(self, dataloader, train):
    # do nothing when user gives a sampler
    dl_args = {
        'dataset': dataloader.dataset,
        'batch_size': dataloader.batch_size,
        'shuffle': False,
        'num_workers': dataloader.num_workers,
        'collate_fn': dataloader.collate_fn,
        'pin_memory': dataloader.pin_memory,
        'drop_last': dataloader.drop_last,
        'timeout': dataloader.timeout,
        'worker_init_fn': dataloader.worker_init_fn
    }

    if train:
        if self.use_ddp or self.use_ddp2:
            sampler = DistributedSampler(dataloader.dataset)
            dl_args['shuffle'] = False
        elif self.use_tpu:
            sampler = DistributedSampler(dataloader.dataset,
                                         num_replicas=xm.xrt_world_size(),
                                         rank=xm.get_ordinal())
            dl_args['shuffle'] = False
        else:
            sampler = RandomSampler(dataloader.dataset)
    # on not train
    else:
        if self.use_tpu:
            sampler = DistributedSampler(dataloader.dataset,
                                         num_replicas=xm.xrt_world_size(),
                                         rank=xm.get_ordinal())
            dl_args['shuffle'] = False
        else:
            sampler = SequentialSampler(dataloader.dataset)

    dl_args['sampler'] = sampler
    new_dataloader = DataLoader(**dl_args)
    return new_dataloader
def get_dataloader(path, batch_size, sequence_length, num_workers):
    dataset = LazyDataset(path, sequence_length + 1)
    if xm.xrt_world_size() > 1:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    return torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers)
def process_index(self):
    """
    The index of the current process used.
    """
    if is_torch_tpu_available():
        return xm.get_ordinal()
    elif is_sagemaker_mp_enabled():
        return smp.dp_rank()
    elif is_sagemaker_dp_enabled():
        return sm_dist.get_rank()
    elif self.local_rank != -1:
        return torch.distributed.get_rank()
    return 0
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
    if is_torch_tpu_available() and xm.xrt_world_size() > 1:
        num_replicas = xm.xrt_world_size()
        rank = xm.get_ordinal()
    elif self.args.local_rank != -1:
        num_replicas = torch.distributed.get_world_size()
        rank = self.args.local_rank
    else:
        num_replicas = 1
        rank = 0
    return MultiTaskBatchSampler(self.dataset_sizes,
                                 self.args.train_batch_size,
                                 self.args.temperature,
                                 rank=rank,
                                 num_replicas=num_replicas)
def get_rank() -> int:
    """
    Returns the rank of the current worker.

    Returns:
        int: ``rank`` if torch.distributed is initialized, otherwise ``-1``
    """
    if _is_xla_distributed_initialized():
        return xm.get_ordinal()
    elif _is_torch_distributed_initialized():
        return dist.get_rank()
    else:
        return -1
def local_process_index(self):
    """
    The index of the local process used.
    """
    if is_torch_tpu_available():
        # xm.get_ordinal() has no `local` keyword; the local ordinal has its own helper.
        return xm.get_local_ordinal()
    elif is_sagemaker_mp_enabled():
        return smp.local_rank()
    elif is_sagemaker_dp_enabled():
        return sm_dist.get_rank()
    elif self.local_rank != -1:
        return self.local_rank
    return 0
def __init__(self,
             train_df: DataFrame,
             valid_df: DataFrame,
             hparams: TrainingArgs,
             image_dir: PathType,
             batch_size: int,
             num_workers: int = 4,
             use_weighted_sampler: bool = True,
             **kwargs):
    super().__init__()
    self.logger = logging.getLogger(__name__)
    self.image_dir = image_dir
    self.image_size = hparams.image_size
    self.crop_size = hparams.crop_size
    self.train_df = train_df
    self.valid_df = valid_df
    self.hparams = hparams
    self.tpu_cores = hparams.tpu_cores
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.use_weighted_sampler = use_weighted_sampler
    self.limit_samples_to_draw = hparams.limit_samples_to_draw
    self.replacement = hparams.replacement

    if self.limit_samples_to_draw and self.use_weighted_sampler:
        n_uniq_classes = self.train_df.landmark_id.nunique()
        # n_samples = int(n_uniq_classes * self.train_df.landmark_id.value_counts().mean())
        n_samples = min(self.limit_samples_to_draw, len(self.train_df))
    else:
        n_samples = len(self.train_df)

    self.sampler = None
    if self.use_weighted_sampler:
        self.logger.info(
            f'Using weighted sampler with total {n_samples} to draw. Replacement: {self.replacement}')
        imbalanced_sampler = get_imbalanced_sampler(
            self.train_df.landmark_id, num_samples=n_samples, replacement=self.replacement)
        if self.tpu_cores is not None and self.tpu_cores > 1:
            import torch_xla.core.xla_model as xm
            distributed_sampler = DistributedSamplerWrapper(
                imbalanced_sampler,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=self.hparams.shuffle
            )
            self.sampler = distributed_sampler
        else:
            self.sampler = imbalanced_sampler

    self.collate_fn = None
    self.train_dataset = None
    self.valid_dataset = None
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
    # don't do anything if it's not a dataloader
    # don't manipulate iterable datasets
    is_dataloader = isinstance(dataloader, DataLoader)

    is_iterable_ds = False
    if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
        is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)

    if not is_dataloader or is_iterable_ds:
        return dataloader

    need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
    if self.replace_sampler_ddp and need_dist_sampler:
        skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

        dl_args = {
            k: v for k, v in dataloader.__dict__.items()
            if not k.startswith('_') and k not in skip_keys
        }

        if self.use_tpu:
            sampler = DistributedSampler(
                dataloader.dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
            )
        elif self.use_horovod:
            sampler = DistributedSampler(dataloader.dataset,
                                         num_replicas=hvd.size(),
                                         rank=hvd.rank())
        else:
            world_size = {
                'ddp': self.num_nodes * self.num_processes,
                'ddp2': self.num_nodes,
                'ddp_cpu': self.num_processes * self.num_nodes
            }
            sampler = DistributedSampler(
                dataloader.dataset,
                num_replicas=world_size.get(self.distributed_backend, 0),
                rank=self.proc_rank,
            )

        dl_args['sampler'] = sampler
        dataloader = type(dataloader)(**dl_args)

    return dataloader
def prepare_task(args, xla_device):
    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build models and criteria to print some metadata
    torch.manual_seed(args.seed)
    model, criterion = task.build_model(args), task.build_criterion(args)
    xm.master_print(model)
    xm.master_print('| model {}, criterion {}'.format(
        args.arch, criterion.__class__.__name__))
    xm.master_print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    model = model.to(xla_device)
    trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
    lr = trainer.get_lr()

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator. We overwrite distributed args here to
    # shard data using torch_xla's distributed training.
    trainer.args.distributed_rank = xm.get_ordinal()
    trainer.args.distributed_world_size = xm.xrt_world_size()
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    trainer.args.distributed_rank = 0
    trainer.args.distributed_world_size = 1
    trainer.meters_to_device(xla_device)
    valid_subsets = args.valid_subset.split(',')
    ordinal = xm.get_ordinal(defval=-1)
    device_str = (
        str(xla_device) if ordinal < 0 else '{}/{}'.format(xla_device, ordinal)
    )
    return task, trainer, model, epoch_itr, lr, valid_subsets, device_str
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
    if isinstance(eval_dataset, torch.utils.data.IterableDataset):
        return None
    elif is_torch_tpu_available():
        return SequentialDistributedSampler(
            eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
    elif self.args.local_rank != -1:
        return SequentialDistributedSampler(eval_dataset)
    else:
        return SequentialSampler(eval_dataset)
def _test_idist_methods_in_xla_context_in_child_proc(index):
    # We explicitly set _model as _SerialModel,
    # then call idist.* methods and check that they give correct values
    from ignite.distributed.utils import _SerialModel, _set_model

    _set_model(_SerialModel())

    import torch_xla.core.xla_model as xm

    _test_distrib_config(local_rank=index,
                         backend="xla-tpu",
                         ws=xm.xrt_world_size(),
                         true_device="xla",
                         rank=xm.get_ordinal())
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data, x)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS['batch_size'])
        if x % FLAGS['log_steps'] == 0:
            print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), x, loss.item(), tracker.rate(),
                tracker.global_rate(), time.asctime()), flush=True)
def _get_distributed_sampler(self, dataloader):
    if self.use_tpu:
        kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
    elif self.use_horovod:
        kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
    else:
        world_size = {
            'ddp': self.num_nodes * self.num_processes,
            'ddp2': self.num_nodes,
            'ddp_cpu': self.num_processes * self.num_nodes
        }
        kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
    sampler = DistributedSampler(dataloader.dataset, **kwargs)
    return sampler
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    total_samples_train, correct_train = 0, 0

    # Training and calculating train accuracy and loss
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        train_loss = loss_fn(output, target)
        train_loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(data.shape[0])
        pred_train = output.max(1, keepdim=True)[1]
        correct_train += pred_train.eq(target.view_as(pred_train)).sum().item()
        total_samples_train += data.size()[0]
        scheduler.step()
        if x % 40 == 0:
            print(
                "[xla:{}]({})\tLoss={:.3f}\tRate={:.2f}\tGlobalRate={:.2f}".format(
                    xm.get_ordinal(),
                    x,
                    train_loss.item(),
                    tracker.rate(),
                    tracker.global_rate(),
                ),
                flush=True,
            )

    train_accuracy = 100.0 * correct_train / total_samples_train
    print(
        "[xla:{}] Accuracy={:.2f}%".format(xm.get_ordinal(), train_accuracy),
        flush=True,
    )
    return train_accuracy
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")

    eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

    if self.args.use_bucket_iterator:
        bucket_boundaries = [0, 20, 30, 40, 50, 60, 70, 80, 90, 101]
        eval_sampler = BySequenceLengthSampler(
            eval_dataset,
            bucket_boundaries,
            batch_size=self.args.eval_batch_size,
            drop_last=False)
        data_loader = DataLoader(
            eval_dataset,
            batch_size=1,
            batch_sampler=eval_sampler,
            collate_fn=self.data_collator.collate_batch,
            num_workers=0,
            pin_memory=False)
    else:
        if is_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)
        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator.collate_batch,
        )

    if is_tpu_available():
        data_loader = pl.ParallelLoader(
            data_loader, [self.args.device]).per_device_loader(self.args.device)

    return data_loader
def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target
def dummy_reduce_scatter(reduce_type, input, scale, scatter_dim, shard_count, groups=None):
    """A dummy op for debugging with the same output shape as reduce_scatter"""
    assert shard_count == xm.xrt_world_size()
    full_size = input.size(scatter_dim)
    shard_size = full_size // xm.xrt_world_size()
    begin = shard_size * xm.get_ordinal()
    end = begin + shard_size
    # Keep every non-scatter dimension intact; indexing with None would insert a new axis
    # instead of selecting the full range.
    slices = [slice(None)] * input.dim()
    slices[scatter_dim] = slice(begin, end)
    return input[tuple(slices)] * scale
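# Usage sketch (an assumption): dummy_reduce_scatter can stand in for xm.reduce_scatter while
# debugging, returning this core's slice of the input without any cross-core communication.
# `flat_grads` is a hypothetical 1-D gradient tensor.
import torch_xla.core.xla_model as xm

def _debug_shard_of(flat_grads):
    world = xm.xrt_world_size()
    # Same call signature as xm.reduce_scatter(xm.REDUCE_SUM, ...); scale=1/world mimics averaging.
    return dummy_reduce_scatter(
        xm.REDUCE_SUM, flat_grads, scale=1.0 / world, scatter_dim=0, shard_count=world)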
def _mp_fn(rank, flags, model, serial):
    global WRAPPED_MODEL, FLAGS, SERIAL_EXEC
    WRAPPED_MODEL = model
    FLAGS = flags
    SERIAL_EXEC = serial

    accuracy_valid = main(args.config_file)
    # Retrieve tensors that are on TPU core 0 and plot.
    # plot_results(data.cpu(), pred.cpu(), target.cpu())
    xm.master_print(('DONE', accuracy_valid))

    # 4. Save model
    if xm.get_ordinal() == 0:
        WRAPPED_MODEL.to('cpu')
        torch.save(WRAPPED_MODEL.state_dict(),
                   os.path.join(config.model_path, 'model.bin'))
        xm.master_print('saved model.')
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
        # batch = tuple(t.to(self.device) for t in batch)
        loss = model(*batch)  # the last element of batch is the label
        # loss = criterion(output, batch[-1])
        loss.backward()
        # xm.optimizer_step(optimizer)
        # optimizer.zero_grad()
        tracker.add(FLAGS.batch_size)
        if (x + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            # Gradient accumulation: run several backward passes before the optimizer updates
            # the parameters, letting the gradients accumulate in parameter.grad, then perform
            # the parameter update with the accumulated gradients.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
        if xm.get_ordinal() == 0:
            if x % FLAGS.log_steps == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)
def setup(self, trainer: "pl.Trainer") -> None:
    shared_params = find_shared_parameters(self.model)
    self.model_to_device()
    if is_overridden("on_post_move_to_device", self.lightning_module):
        self.model.on_post_move_to_device()
    else:
        set_shared_parameters(self.model, shared_params)

    super().setup(trainer)

    if self.debug:
        os.environ["PT_XLA_DEBUG"] = str(1)

    self.tpu_local_core_rank = xm.get_local_ordinal()
    self.tpu_global_core_rank = xm.get_ordinal()