Example no. 1
def _mp_fn(index):
    ordinal = xm.get_ordinal()
    print('Core {} waiting for rendezvous ...'.format(ordinal))
    data = xm.rendezvous('rendezvous_test', 'ORD={}'.format(ordinal))
    print('Core {} got rendezvous!'.format(ordinal))
    for i in range(0, len(data)):
        m = re.match(r'ORD=(\d+)', data[i])
        assert m, 'Bad payload format: {}'.format(data[i])
        xordinal = int(m.group(1))
        assert i == xordinal, 'Payload {} got ordinal {}'.format(i, xordinal)
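The per-core function above is typically launched with xmp.spawn, which starts one process per TPU core and passes the process index as the first argument; xm.get_ordinal() inside the function then returns that core's global rank. A minimal launch sketch, assuming the _mp_fn above is defined in the same module:

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # Start one process per available TPU core; each process runs _mp_fn(index).
    xmp.spawn(_mp_fn, args=())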
Example no. 2
def is_main_process(local_rank):
    """
    Whether or not the current process is the main process, based on `xm.get_ordinal()` (for TPUs) first, then on
    `local_rank`.
    """
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm

        return xm.get_ordinal() == 0
    return local_rank in [-1, 0]
Example no. 3
def get_rank():
    if is_xla():
        return xm.get_ordinal()
    if not dist.is_available():
        return 0
    if not dist.is_nccl_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()
Example no. 4
def distributed_init(args):
    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)
    return args.distributed_rank
Example no. 5
 def make_loader(self, dataset, batch_size, phase):
     shuffle = phase == "train"
     if self.using_tpu:
         sampler = DistributedSampler(dataset=dataset,
                                      num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(),
                                      shuffle=shuffle)
         loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                             num_workers=xm.xrt_world_size(), drop_last=True)
     else:
         loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
     return loader
Example no. 6
def build_dataloader(args, tokenizer):
    train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
    train_sampler = RandomSampler(train_dataset)
    if xm.xrt_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=xm.xrt_world_size(),
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    return DataLoader(train_dataset,
                      sampler=train_sampler,
                      batch_size=args.train_batch_size)
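One detail the sampler setup above leaves to the caller: DistributedSampler only reshuffles when its epoch counter is advanced. A hypothetical training-loop fragment, assuming the build_dataloader above plus existing args and tokenizer objects:

from torch.utils.data.distributed import DistributedSampler

def run_training(args, tokenizer, num_epochs):
    loader = build_dataloader(args, tokenizer)
    for epoch in range(num_epochs):
        # Without set_epoch(), every epoch replays the same shuffled order on every core.
        if isinstance(loader.sampler, DistributedSampler):
            loader.sampler.set_epoch(epoch)
        for batch in loader:
            pass  # forward/backward as usual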
Example no. 7
 def _wrap_dl(self, dl):
     if isinstance(dl, pl.PerDeviceLoader):
         return dl
     else:
         #dl = dl.to(self.device)
         dl.fake_l.num_workers = 0  # Required for this to work (appears to be a fastai issue); needs further investigation
         distributed_dl = TPUDistributedDL(
             dl, xm.get_ordinal(),
             xm.xrt_world_size())  # Use existing distributed functionality
         return pl.ParallelLoader(
             distributed_dl, [self.device]).per_device_loader(self.device)
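Outside of the fastai-specific wrapper above, the usual PyTorch/XLA pattern is to wrap a plain DataLoader in MpDeviceLoader, which performs the per-device iteration and host-to-device transfer itself. A minimal sketch, assuming an ordinary DataLoader named dl:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
# MpDeviceLoader yields batches already placed on `device` for this process's core.
device_loader = pl.MpDeviceLoader(dl, device)
for batch in device_loader:
    pass  # per-core training step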
Example no. 8
 def process_index(self):
     """
     The index of the current process used.
     """
     if is_torch_tpu_available():
         return xm.get_ordinal()
     elif is_sagemaker_distributed_available():
         return sm_dist.get_rank()
     elif self.local_rank != -1:
         return torch.distributed.get_rank()
     return 0
Example no. 9
 def _train_xla(self, rank, loader, loader_valid, num_epochs):
     seed_everything(self.random_state, self.deterministic)
     self.device = xm.xla_device()
     self.rank = xm.get_ordinal()
     self._configure_model()
     loader = self._configure_loader_ddp(loader)
     loader_valid = self._configure_loader_ddp(loader_valid, shuffle=False)
     self._train(
         loader.per_device_loader(self.device),
         loader_valid.per_device_loader(self.device),
         num_epochs)
Example no. 10
 def test_loop_fn(loader):
     model.eval()
     for x, (data, label) in enumerate(loader):
         output = model(image=data,
                        label=label,
                        get_embedding=args.get_embeddings)
         loss = loss_fn(output, label)
         if x % 20 == 0:
             print('[xla:{}]({}) Loss={:.5f}'.format(
                 xm.get_ordinal(), x, loss.item()),
                   flush=True)
Example no. 11
    def auto_add_sampler(self, dataloader, train):
        # do nothing when user gives a sampler
        dl_args = {
            'dataset': dataloader.dataset,
            'batch_size': dataloader.batch_size,
            'shuffle': False,
            'num_workers': dataloader.num_workers,
            'collate_fn': dataloader.collate_fn,
            'pin_memory': dataloader.pin_memory,
            'drop_last': dataloader.drop_last,
            'timeout': dataloader.timeout,
            'worker_init_fn': dataloader.worker_init_fn
        }

        if train:
            if self.use_ddp or self.use_ddp2:
                sampler = DistributedSampler(dataloader.dataset)
                dl_args['shuffle'] = False

            elif self.use_tpu:
                sampler = DistributedSampler(dataloader.dataset,
                                             num_replicas=xm.xrt_world_size(),
                                             rank=xm.get_ordinal())
                dl_args['shuffle'] = False
            else:
                sampler = RandomSampler(dataloader.dataset)

        # on not train
        else:
            if self.use_tpu:
                sampler = DistributedSampler(dataloader.dataset,
                                             num_replicas=xm.xrt_world_size(),
                                             rank=xm.get_ordinal())
                dl_args['shuffle'] = False
            else:
                sampler = SequentialSampler(dataloader.dataset)

        dl_args['sampler'] = sampler

        new_dataloader = DataLoader(**dl_args)
        return new_dataloader
Example no. 12
def get_dataloader(path, batch_size, sequence_length, num_workers):
  dataset = LazyDataset(path, sequence_length + 1)
  if xm.xrt_world_size() > 1:
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
  else:
    sampler = torch.utils.data.RandomSampler(dataset)
  return torch.utils.data.DataLoader(
      dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)
Example no. 13
 def process_index(self):
     """
     The index of the current process used.
     """
     if is_torch_tpu_available():
         return xm.get_ordinal()
     elif is_sagemaker_mp_enabled():
         return smp.dp_rank()
     elif is_sagemaker_dp_enabled():
         return sm_dist.get_rank()
     elif self.local_rank != -1:
         return torch.distributed.get_rank()
     return 0
Example no. 14
 def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
     if is_torch_tpu_available() and xm.xrt_world_size() > 1:
         num_replicas = xm.xrt_world_size()
         rank = xm.get_ordinal()
     elif self.args.local_rank != -1:
         num_replicas = torch.distributed.get_world_size()
         rank = self.args.local_rank
     else:
         num_replicas = 1
         rank = 0
     return MultiTaskBatchSampler(self.dataset_sizes, self.args.train_batch_size,
                                  self.args.temperature, rank=rank,
                                  num_replicas=num_replicas)
Example no. 15
def get_rank() -> int:
    """
    Returns the rank of the current worker.

    Returns:
        int: ``rank`` if XLA or torch.distributed is initialized, otherwise ``-1``
    """
    if _is_xla_distributed_initialized():
        return xm.get_ordinal()
    elif _is_torch_distributed_initialized():
        return dist.get_rank()
    else:
        return -1
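A typical use of a rank helper like this is to confine side effects such as logging or checkpoint writes to a single worker. A hypothetical usage sketch built on the get_rank above:

def print_once(message):
    # Rank 0 in a distributed run, or -1 when no distributed backend is
    # initialized, is treated as the process allowed to print.
    if get_rank() in (-1, 0):
        print(message)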
Example no. 16
 def local_process_index(self):
     """
     The index of the local process used.
     """
     if is_torch_tpu_available():
         return xm.get_ordinal(local=True)
     elif is_sagemaker_mp_enabled():
         return smp.local_rank()
     elif is_sagemaker_dp_enabled():
         return sm_dist.get_rank()
     elif self.local_rank != -1:
         return self.local_rank
     return 0
Example no. 17
    def __init__(self,
                 train_df: DataFrame, valid_df: DataFrame,
                 hparams: TrainingArgs,
                 image_dir: PathType,
                 batch_size: int,
                 num_workers: int = 4,
                 use_weighted_sampler: bool = True,
                 **kwargs
                 ):
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.image_dir = image_dir
        self.image_size = hparams.image_size
        self.crop_size = hparams.crop_size
        self.train_df = train_df
        self.valid_df = valid_df
        self.hparams = hparams
        self.tpu_cores = hparams.tpu_cores
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_weighted_sampler = use_weighted_sampler
        self.limit_samples_to_draw = hparams.limit_samples_to_draw
        self.replacement = hparams.replacement

        if self.limit_samples_to_draw and self.use_weighted_sampler:
            n_uniq_classes = self.train_df.landmark_id.nunique()
            # n_samples = int(n_uniq_classes * self.train_df.landmark_id.value_counts().mean())
            n_samples = min(self.limit_samples_to_draw, len(self.train_df))
        else:
            n_samples = len(self.train_df)

        self.sampler = None
        if self.use_weighted_sampler:
            self.logger.info(f'Using weighted sampler with total {n_samples} to draw. Replacement: {self.replacement}')
            imbalanced_sampler = get_imbalanced_sampler(self.train_df.landmark_id, num_samples=n_samples,
                                                        replacement=self.replacement)
            if self.tpu_cores is not None and self.tpu_cores > 1:
                import torch_xla.core.xla_model as xm
                distributed_sampler = DistributedSamplerWrapper(
                    imbalanced_sampler,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=self.hparams.shuffle
                )
                self.sampler = distributed_sampler
            else:
                self.sampler = imbalanced_sampler

        self.collate_fn = None
        self.train_dataset = None
        self.valid_dataset = None
Example no. 18
    def auto_add_sampler(self, dataloader: DataLoader,
                         train: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        # don't manipulate iterable datasets
        is_dataloader = isinstance(dataloader, DataLoader)

        is_iterable_ds = False
        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)

        if not is_dataloader or is_iterable_ds:
            return dataloader
        need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod
                             or self.use_tpu)
        if self.replace_sampler_ddp and need_dist_sampler:

            skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

            dl_args = {
                k: v
                for k, v in dataloader.__dict__.items()
                if not k.startswith('_') and k not in skip_keys
            }

            if self.use_tpu:
                sampler = DistributedSampler(
                    dataloader.dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                )
            elif self.use_horovod:
                sampler = DistributedSampler(dataloader.dataset,
                                             num_replicas=hvd.size(),
                                             rank=hvd.rank())
            else:
                world_size = {
                    'ddp': self.num_nodes * self.num_processes,
                    'ddp2': self.num_nodes,
                    'ddp_cpu': self.num_processes * self.num_nodes
                }
                sampler = DistributedSampler(
                    dataloader.dataset,
                    num_replicas=world_size.get(self.distributed_backend, 0),
                    rank=self.proc_rank,
                )

            dl_args['sampler'] = sampler
            dataloader = type(dataloader)(**dl_args)

        return dataloader
Example no. 19
    def prepare_task(args, xla_device):
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str
Example no. 20
 def _get_eval_sampler(
         self, eval_dataset: Dataset
 ) -> Optional[torch.utils.data.sampler.Sampler]:
     if isinstance(eval_dataset, torch.utils.data.IterableDataset):
         return None
     elif is_torch_tpu_available():
         return SequentialDistributedSampler(
             eval_dataset,
             num_replicas=xm.xrt_world_size(),
             rank=xm.get_ordinal())
     elif self.args.local_rank != -1:
         return SequentialDistributedSampler(eval_dataset)
     else:
         return SequentialSampler(eval_dataset)
Example no. 21
def _test_idist_methods_in_xla_context_in_child_proc(index):
    # We explicitly set _model as _SerialModel
    # then call idist.* methods and check that they give correct values
    from ignite.distributed.utils import _SerialModel, _set_model

    _set_model(_SerialModel())

    import torch_xla.core.xla_model as xm

    _test_distrib_config(local_rank=index,
                         backend="xla-tpu",
                         ws=xm.xrt_world_size(),
                         true_device="xla",
                         rank=xm.get_ordinal())
Example no. 22
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data, x)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS['batch_size'])
         if x % FLAGS['log_steps'] == 0:
             print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                 xm.get_ordinal(), x, loss.item(), tracker.rate(),
                 tracker.global_rate(), time.asctime()), flush=True)
Example no. 23
 def _get_distributed_sampler(self, dataloader):
     if self.use_tpu:
         kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
     elif self.use_horovod:
         kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
     else:
         world_size = {
             'ddp': self.num_nodes * self.num_processes,
             'ddp2': self.num_nodes,
             'ddp_cpu': self.num_processes * self.num_nodes
         }
         kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
     sampler = DistributedSampler(dataloader.dataset, **kwargs)
     return sampler
Example no. 24
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        total_samples_train, correct_train = 0, 0

        # Training and calculating train accuracy and loss
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            train_loss = loss_fn(output, target)
            train_loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(data.shape[0])

            pred_train = output.max(1, keepdim=True)[1]
            correct_train += pred_train.eq(target.view_as(pred_train)).sum().item()
            total_samples_train += data.size()[0]

            scheduler.step()
            if x % 40 == 0:
                print(
                    "[xla:{}]({})\tLoss={:.3f}\tRate={:.2f}\tGlobalRate={:.2f}".format(
                        xm.get_ordinal(),
                        x,
                        train_loss.item(),
                        tracker.rate(),
                        tracker.global_rate(),
                    ),
                    flush=True,
                )

        train_accuracy = 100.0 * correct_train / total_samples_train
        print(
            "[xla:{}] Accuracy={:.2f}%".format(xm.get_ordinal(), train_accuracy),
            flush=True,
        )
        return train_accuracy
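The accuracy printed above is per-core, since each process only sees its own shard of the data. PyTorch/XLA code commonly aggregates such Python scalars across cores with xm.mesh_reduce; a sketch under the assumption that train_loop_fn returns its per-core accuracy and that train_device_loader is the per-device loader (the loader name is illustrative):

import numpy as np
import torch_xla.core.xla_model as xm

per_core_accuracy = train_loop_fn(train_device_loader)
# mesh_reduce gathers the value from every core and applies the reduce function.
global_accuracy = xm.mesh_reduce('train_accuracy', per_core_accuracy, np.mean)
xm.master_print('Global train accuracy: {:.2f}%'.format(global_accuracy))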
Example no. 25
    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.args.use_bucket_iterator:

            bucket_boundaries = [0, 20, 30, 40, 50, 60, 70, 80, 90, 101]
            eval_sampler = BySequenceLengthSampler(
                eval_dataset,
                bucket_boundaries,
                batch_size=self.args.eval_batch_size,
                drop_last=False)

            data_loader = DataLoader(
                eval_dataset,
                batch_size=1,
                batch_sampler=eval_sampler,
                collate_fn=self.data_collator.collate_batch,
                num_workers=0,
                pin_memory=False)
        else:

            if is_tpu_available():
                sampler = SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal())
            elif self.args.local_rank != -1:
                sampler = SequentialDistributedSampler(eval_dataset)
            else:
                sampler = SequentialSampler(eval_dataset)

            data_loader = DataLoader(
                eval_dataset,
                sampler=sampler,
                batch_size=self.args.eval_batch_size,
                collate_fn=self.data_collator.collate_batch,
            )

        if is_tpu_available():
            data_loader = pl.ParallelLoader(
                data_loader,
                [self.args.device]).per_device_loader(self.args.device)

        return data_loader
Example no. 26
    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target
Example no. 27
File: utils.py Project: pytorch/xla
def dummy_reduce_scatter(reduce_type,
                         input,
                         scale,
                         scatter_dim,
                         shard_count,
                         groups=None):
    """A dummy op for debugging with the same output shape as reduce_scatter"""
    assert shard_count == xm.xrt_world_size()
    full_size = input.size(scatter_dim)
    shard_size = full_size // xm.xrt_world_size()
    begin = shard_size * xm.get_ordinal()
    end = begin + shard_size
    slices = [None] * input.dim()
    slices[scatter_dim] = slice(begin, end)
    return input[tuple(slices)] * scale
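To make the slicing arithmetic concrete: with 8 cores and a scatter_dim of size 32, each core keeps a 4-element slice starting at ordinal * 4. A small self-contained sketch of the same index math, with the world size and ordinal passed in explicitly instead of being read from xm:

def shard_bounds(full_size, world_size, ordinal):
    # Mirrors the begin/end computation in dummy_reduce_scatter above.
    shard_size = full_size // world_size
    begin = shard_size * ordinal
    return begin, begin + shard_size

# e.g. 8 cores and a dimension of size 32: core 3 keeps elements [12, 16).
assert shard_bounds(32, 8, 3) == (12, 16)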
Example no. 28
def _mp_fn(rank, flags, model, serial):
    global WRAPPED_MODEL, FLAGS, SERIAL_EXEC
    WRAPPED_MODEL = model
    FLAGS = flags
    SERIAL_EXEC = serial

    accuracy_valid = main(args.config_file)
    # Retrieve tensors that are on TPU core 0 and plot.
    # plot_results(data.cpu(), pred.cpu(), target.cpu())
    xm.master_print(('DONE',  accuracy_valid))
    # 4. Save model
    if xm.get_ordinal() == 0:
        WRAPPED_MODEL.to('cpu')
        torch.save(WRAPPED_MODEL.state_dict(), os.path.join(config.model_path, 'model.bin'))
        xm.master_print('saved model.')
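An alternative to the manual ordinal check above is xm.save, which moves tensors to CPU and writes only from the master ordinal (with an internal rendezvous), so it is safe to call from every core. A minimal sketch reusing WRAPPED_MODEL and config.model_path from the example:

import os
import torch_xla.core.xla_model as xm

# Every core may call this; only the master ordinal actually writes the file.
xm.save(WRAPPED_MODEL.state_dict(), os.path.join(config.model_path, 'model.bin'))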
Example no. 29
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            # batch = tuple(t.to(self.device) for t in batch)
            loss = model(*batch)  # the last one is label
            #loss = criterion(output, batch[-1])
            loss.backward()
            # xm.optimizer_step(optimizer)
            # optimizer.zero_grad()

            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.max_grad_norm)
                # The basic idea of gradient accumulation: run several backward passes before the optimizer updates the parameters (i.e. before optimizer.step()), letting the gradients accumulate in parameter.grad, then perform a single parameter update with the accumulated gradient.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                        xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()), flush=True)
Example no. 30
    def setup(self, trainer: "pl.Trainer") -> None:
        shared_params = find_shared_parameters(self.model)
        self.model_to_device()
        if is_overridden("on_post_move_to_device", self.lightning_module):
            self.model.on_post_move_to_device()
        else:
            set_shared_parameters(self.model, shared_params)

        super().setup(trainer)

        if self.debug:
            os.environ["PT_XLA_DEBUG"] = str(1)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()