Example #1
    def barrier(self) -> None:
        """
        Synchronizes all processes.

        This collective blocks processes until all of them enter the function.
        """
        xm.rendezvous("barrier")
Example #2
File: train.py  Project: molamk/aitextgen
    def on_batch_end(self, trainer, pl_module):
        super().on_batch_end(trainer, pl_module)

        # clean up the GPU cache used for the benchmark
        # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
        if self.steps == 0 and self.gpu:
            torch.cuda.empty_cache()

        current_loss = float(trainer.progress_bar_dict["loss"])
        self.steps += 1
        avg_loss = 0
        if current_loss == current_loss:  # don't add if current_loss is NaN
            avg_loss = self.average_loss(current_loss, self.prev_avg_loss,
                                         self.smoothing)
            self.prev_avg_loss = avg_loss

        desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}"

        if self.steps % self.progress_bar_refresh_rate == 0:
            if self.gpu:
                desc += f" — GPU Mem: {list(get_gpu_memory_map().values())[0]} MB"
            self.main_progress_bar.update(self.progress_bar_refresh_rate)
            self.main_progress_bar.set_description(desc)

        if self.enabled:

            if self.save_every > 0 and self.steps % self.save_every == 0:
                if pl_module.hparams["tpu"]:
                    xm.rendezvous("save_model")
                self.save_pytorch_model(trainer, pl_module)

            if (not pl_module.hparams["tpu"] and self.generate_every > 0
                    and self.steps % self.generate_every == 0):
                self.generate_sample_text(trainer, pl_module)
Example #3
def save(data, path, master_only=True, global_master=False):
    """Saves the input data into a file.

  The saved data is transferred to PyTorch CPU device before being saved, so a
  following `torch.load()` will load CPU data.
  Care must be taken when working with views. Instead of saving views it's
  recommended that you recreate them after the tensors have been loaded and
  moved to their destination device(s).

  Args:
    data: The input data to be saved. Any nested combination of Python objects
      (list, tuples, sets, dicts, ...).
    path: The destination file for the data saving operation. If `master_only`
      is ``False`` the path must point to different destinations as otherwise
      all the writes from the same host will overwrite each other.
    master_only (bool, optional): Whether only the master device should save the
      data. If False, the `path` argument should be a different path for each of
      the ordinals taking part to the replication, otherwise all the replicas on
      the same host will be writing to the same location.
      Default: True
    global_master (bool, optional): When ``master_only`` is ``True`` this flag
      controls whether every host's master (if ``global_master`` is ``False``)
      saves the content, or only the global master (ordinal 0).
      Default: False
  """
    should_write_data = not master_only or xm.is_master_ordinal(
        local=not global_master)

    ref_data = _rewrite_data(_get_tensors_folder(path), data,
                             should_write_data)
    if should_write_data:
        torch.save(ref_data, path)
    xm.rendezvous('torch_xla.utils.serialization.save')
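A hedged usage sketch for the `save` helper above, assuming it is the version shipped as `torch_xla.utils.serialization` (whose companion `load` reverses the tensor rewriting); the model and path are illustrative.

# Illustrative usage, run inside each spawned replica.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.serialization as xser

device = xm.xla_device()
model = torch.nn.Linear(16, 4).to(device)

# With master_only=True (the default) only the master ordinal writes, but every
# replica must still call save() so the trailing rendezvous completes.
xser.save({"model": model.state_dict()}, "/tmp/ckpt.pt")

# Tensors are loaded back on CPU; move them to the XLA device as needed.
state = xser.load("/tmp/ckpt.pt")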
Example #4
    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
        should_stop = torch.tensor(int(should_stop),
                                   device=self.lightning_module.device)
        stop = xm.mesh_reduce('stop_signal', should_stop, sum)
        xm.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
        should_stop = int(stop.item()) == self.world_size
        return should_stop
Example #5
def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)
    return args.distributed_rank
Example #6
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        if self.tpu:
            import torch_xla.core.xla_model as xm

            xm.rendezvous("valid_step")  # wait for all workers
            xm.mark_step()

        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                if self._dummy_batch == "DUMMY":
                    self._dummy_batch = sample
                is_dummy_batch = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        logger.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                if torch.is_tensor(sample_size):
                    sample_size.zero_()
                else:
                    sample_size *= 0.0

        # gather logging outputs from all replicas
        if self.data_parallel_world_size > 1:
            logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
            )

        # log validation stats
        logging_output = self._reduce_and_log_stats(logging_outputs,
                                                    sample_size)

        return logging_output
Example #7
def run_validation(e: Engine):
    if distributed:
        torch.cuda.synchronize(device)
    if use_tpu:
        xm.rendezvous('validate_{}'.format(e.state.iteration))
        valid_it = valid_dl.per_device_loader(device)
        evaluator.run(valid_it, epoch_length=len(valid_dl))
    else:
        evaluator.run(valid_dl)
Example #8
def tpu_data_loader(args, itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()
    device = utils.get_tpu_device(args)
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, 'n', 0),
        total=len(itr),
    )
Example #9
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args,
                   ctrl_args):
    "run fit method on spawned process"
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args,
                                     ctrl_args)
    fit_args = setup_fit_cbs(rank, fit_args)
    fit_method(learner, **fit_args)
    xm.rendezvous('xla_run_method')
    learner.save('_xla_tmp_model', rendezvous=False)
    xm.mark_step()
Example #10
def wait_for_everyone():
    """
    Introduces a blocking point in the script, making sure all processes have reached this point before continuing.

    Warning::

        Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
    """
    if AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
        torch.distributed.barrier()
    elif AcceleratorState().distributed_type == DistributedType.TPU:
        xm.rendezvous("accelerate.utils.wait_for_everyone")
Example #11
    def forward(self,
                z,
                gy,
                x=None,
                dy=None,
                train_G=False,
                return_G_z=False,
                split_D=False):
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()
        # Split_D means to run D once with real data and once with fake,
        # rather than concatenating along the batch dimension.
        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            else:
                if return_G_z:
                    return D_fake, G_z
                else:
                    return D_fake
        # If real data is provided, concatenate it with the Generator's output
        # along the batch dimension for improved efficiency.
        else:
            D_input = torch.cat([G_z, x], 0) if x is not None else G_z
            self.counter += 1
            if xm.is_master_ordinal() and (self.counter % 50 == 0):
                torchvision.utils.save_image(D_input[:64].float().cpu(),
                                             'd_inp.png',
                                             nrow=int(D_input.shape[0]**0.5),
                                             normalize=True)
            xm.rendezvous('cont....')
            D_class = torch.cat([gy, dy], 0) if dy is not None else gy
            # Get Discriminator output
            D_out = self.D(D_input, D_class)

            if x is not None:
                return torch.split(
                    D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
            else:
                if return_G_z:
                    return D_out, G_z
                else:
                    return D_out
Example #12
def tpu_data_loader(itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, "n", 0),
        total=len(itr),
    )
Example #13
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    if not xm.is_master_ordinal():
        xm.rendezvous("load_and_cache_examples")

    processor = processors[task]()
    output_mode = output_modes[task]
    cached_features_file = os.path.join(
        args.cache_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )

    # Load data features from cache or dataset file
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = (
            processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        )
        features = convert_examples_to_features(
            examples, tokenizer, max_length=args.max_seq_length, label_list=label_list, output_mode=output_mode,
        )
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    if xm.is_master_ordinal():
        xm.rendezvous("load_and_cache_examples")

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset
Example #14
    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)
Example #15
    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch."""
        logger.info("begin training epoch {}".format(epoch))

        if self.quantizer is not None:
            self.quantizer.begin_epoch(epoch)

        # task specific setup per epoch
        self.task.begin_epoch(epoch, self.get_model())

        if self.tpu:
            import torch_xla.core.xla_model as xm

            xm.rendezvous('begin_epoch')  # wait for all workers
            xm.mark_step()
Example #16
def synchronize(message="sync-workers"):
    if is_xla():
        xm.rendezvous(message)
    elif not dist.is_available():
        return
    if not dist.is_nccl_available():
        return
    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()
Example #17
def _mp_fn(index):
  ordinal = xm.get_ordinal()
  print('Core {} waiting for rendezvous ...'.format(ordinal))
  replicas, gid = _get_replica_group(index)
  data = xm.rendezvous(
      'rendezvous_test.{}'.format(gid),
      'ORD={}'.format(ordinal).encode('utf-8'),
      replicas=replicas)
  print('Core {} got rendezvous!'.format(ordinal))
  for i in range(0, len(data)):
    idata = data[i].decode('utf-8')
    m = re.match(r'ORD=(\d+)', idata)
    assert m, 'Bad payload format: {}'.format(idata)
    xordinal = int(m.group(1))
    assert replicas[i] == xordinal, 'Payload {} got ordinal {}'.format(
        replicas[i], xordinal)
  xm.rendezvous('_mp_fn.exit')
Example #18
def simple_map_fn(index, flags):
    # Sets a common random seed - both for initialization and ensuring graph is the same
    torch.manual_seed(1234)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()

    # Creates a tensor on this process's device
    t = torch.randn((2, 2), device=device)

    print("device: ", device, " str: ", str(device), " real: ",
          xm.xla_real_devices([str(device)]))
    print("master: ", xm.is_master_ordinal())
    print("Process", index, "is using", xm.xla_real_devices([str(device)])[0])

    # Barrier to prevent master from exiting before workers connect.
    xm.rendezvous('init')
Example #19
def load_cifar_10_xla(batch_size, root='/tmp/cifar10'):
    """Download and load the CIFAR-10 dataset."""
    if not xm.is_master_ordinal():
        # Barrier: Wait until master is done downloading
        xm.rendezvous('download_only_once')

    # Get and shard dataset into dataloaders
    norm = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), norm
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), norm])

    cifar10_train = datasets.CIFAR10(root=root,
                                     train=True,
                                     download=True,
                                     transform=transform_train)
    cifar10_test = datasets.CIFAR10(root=root,
                                    train=False,
                                    download=True,
                                    transform=transform_test)

    if xm.is_master_ordinal():
        # Barrier: Master done downloading, other workers can proceed
        xm.rendezvous('download_only_once')

    train_sampler = DistributedSampler(cifar10_train,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)
    train_loader = DataLoader(cifar10_train,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=4,
                              drop_last=True)
    test_loader = DataLoader(cifar10_test,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=4,
                             drop_last=True)

    return train_loader, test_loader
Example #20
def broadcast_xla_master_model_param(model):
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ..."
    )

    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is equivalent
        # to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=p.data.dtype, device=p.data.device)
            p.data.mul_(zero)
        parameters_and_buffers.append(p.data)
    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
Example #21
def xla_run_lr_find(rank, learner_args, add_args, lr_find_args, ctrl_args):
    xm.rendezvous('start_xla_run_lr_find')
    # print(f'xla {rank} : start run lrfind')
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args,
                                     ctrl_args)

    num_it = lr_find_args['num_it']
    n_epoch = num_it // len(learner.dls.train) + 1
    learner.opt = None
    learner.create_opt()
    cb = XLALRFinder(**lr_find_args)

    skip_valid_cb = SkipValidationCallback()
    synced_cancel_cb = SyncedCancelCallback()

    with learner.no_logging():
        learner.fit(n_epoch, cbs=[cb, skip_valid_cb, synced_cancel_cb])
Example #22
    def on_batch_end(self, trainer, pl_module):
        super().on_batch_end(trainer, pl_module)

        # clean up the GPU cache used for the benchmark
        # https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/4
        if self.steps == 0 and self.gpu:
            torch.cuda.empty_cache()

        current_loss = float(trainer.progress_bar_dict["loss"])
        self.steps += 1
        avg_loss = 0
        if current_loss == current_loss:  # don't add if current_loss is NaN
            avg_loss = self.average_loss(current_loss, self.prev_avg_loss,
                                         self.smoothing)
            self.prev_avg_loss = avg_loss

        desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}"

        # Added by serge: save the loss to loss.txt
        if self.steps % self.save_every == 0 or self.steps == 1:
            print("Saving the loss to loss.txt")
            fichier = "./loss.txt"
            data = str(self.steps) + " " + desc + "\n"
            with open(fichier, "a") as fd:
                fd.write(data)

        if self.steps % self.progress_bar_refresh_rate == 0:
            if self.gpu:
                desc += f" — GPU Mem: {get_gpu_memory_map()['gpu_0']} MB"
            self.main_progress_bar.update(self.progress_bar_refresh_rate)
            self.main_progress_bar.set_description(desc)

        if self.enabled:

            if self.save_every > 0 and self.steps % self.save_every == 0:
                if pl_module.hparams["tpu"]:
                    xm.rendezvous("save_model")
                self.save_pytorch_model(trainer, pl_module)

            if (not pl_module.hparams["tpu"] and self.generate_every > 0
                    and self.steps % self.generate_every == 0):
                self.generate_sample_text(trainer, pl_module)
Example #23
def _mp_fn(index):
    ordinal = xm.get_ordinal()
    print('Core {} waiting for rendezvous ...'.format(ordinal))
    data = xm.rendezvous('rendezvous_test', 'ORD={}'.format(ordinal))
    print('Core {} got rendezvous!'.format(ordinal))
    for i in range(0, len(data)):
        m = re.match(r'ORD=(\d+)', data[i])
        assert m, 'Bad payload format: {}'.format(data[i])
        xordinal = int(m.group(1))
        assert i == xordinal, 'Payload {} got ordinal {}'.format(i, xordinal)
Example #24
def _mp_fn(index, _hps, _dat_file):
    t.manual_seed(_hps.random_seed)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    pre_dev_time = time.time()
    device = xm.xla_device()
    device_str = xm.xla_real_devices([str(device)])[0]
    elapsed = time.time() - pre_dev_time
    print(f'process {index} acquired {device_str} in {elapsed} seconds',
          file=stderr,
          flush=True)

    pre_inst_time = time.time()
    m = ch.Chassis(device, index, _hps, _dat_file)
    print(f'Created Chassis in {time.time() - pre_inst_time:3.5} seconds.',
          file=stderr,
          flush=True)
    xm.rendezvous('init')
    m.train()
Example #25
def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args,
                                     ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)

    if rank == 0 and len(master_cbs) > 0:
        learner.add_cbs(master_cbs)

    # learner.synced_cancel.before_fit()

    if rank == 0:
        learner.sync_recorder.orig_logger = learner.logger

    results = learner.inner_get_preds(**pred_args)
    xm.rendezvous('xla_run_inference')

    save_pred_results(rank, results)
    xm.mark_step()
Example #26
def sync_bn1d_no_channel_test(index):
    torch.manual_seed(1)
    bsz = 32
    length = 64
    t_global = torch.rand((xm.xrt_world_size() * bsz, length))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(length).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm1d(length)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn1d_no_channel_test')
    xm.master_print('sync_bn1d_no_channel_test ok')
Example #27
def sync_bn3d_test(index):
    torch.manual_seed(1)
    bsz = 16
    features = 32
    d, h, w = 16, 32, 32
    t_global = torch.rand((xm.xrt_world_size() * bsz, features, d, h, w))

    # XLA SyncBatchNorm
    device = xm.xla_device()
    t_xla = t_global[bsz * index:bsz * (index + 1), ...].to(device)
    sbn_xla = xf.SyncBatchNorm(features).to(device)
    result = run_step(sbn_xla, t_xla)

    # CPU BatchNorm
    bn_cpu = torch.nn.BatchNorm3d(features)
    expected = run_step(bn_cpu, t_global)

    cpu_result = result.cpu()
    assert cpu_result.allclose(expected, rtol=RTOL, atol=ATOL)
    assert_stats(sbn_xla.cpu(), bn_cpu)

    xm.rendezvous('sync_bn3d_test')
    xm.master_print('sync_bn3d_test ok')
Example #28
def save_xla_ckpt(ckpt, file_or_path):
    """
    Similar to xm.save, but only try to convert "model" and "optimizer" in an MMF
    checkpoint to CPU, since they hold PyTorch tensors. Other items like lr_scheduler
    often cannot be saved with xm.save due to its errors in handling mappingproxy.

    Only save on the global main process (which is different from the default behavior
    of xm.save that saves a checkpoint on each node).
    """
    should_write_data = is_main()

    is_full_ckpt = isinstance(ckpt,
                              dict) and "model" in ckpt and "optimizer" in ckpt
    if is_full_ckpt:
        ckpt["model"] = xm._maybe_convert_to_cpu(ckpt["model"],
                                                 convert=should_write_data)
        ckpt["optimizer"] = xm._maybe_convert_to_cpu(ckpt["optimizer"],
                                                     convert=should_write_data)
    else:
        ckpt = xm._maybe_convert_to_cpu(ckpt, convert=should_write_data)

    if should_write_data:
        torch.save(ckpt, file_or_path)
    xm.rendezvous("mmf.utils.checkpoint.save_xla_ckpt")
Example #29
def _mp_fn(index, temp_file):
    device = xm.xla_device()
    dd = _create_state_dict(device)
    xm.save(dd, temp_file)
    ldd = torch.load(temp_file)
    pdd = _get_data_str(ldd)
    data = xm.rendezvous('xm_save_test', pdd)
    if xm.get_local_ordinal() == 0:
        os.remove(temp_file)
    for i in range(1, len(data)):
        bio = io.BytesIO(data[i])
        ildd = torch.load(bio)
        for k, v in ldd.items():
            if isinstance(v, torch.Tensor):
                assert v.allclose(ildd[k])
            elif isinstance(v, (list, tuple)):
                iv = ildd[k]
                for a, b in zip(v, iv):
                    assert a.allclose(b)
            else:
                raise RuntimeError('Invalid data type')
Example #30
    def on_epoch_end(self, callbacks):
        def execute():
            for callback in callbacks:
                callback.on_epoch_end()

            # reset
            for name, metric in self.train_eval.items():
                self.train_eval[name].reset()
            for name, metric in self.dev_eval.items():
                self.dev_eval[name].reset()

        if self.using_tpu:
            xm.rendezvous("train is done!")
            if xm.is_master_ordinal():
                execute()
                xm.rendezvous("on_epoch_end is done!")
            else:
                xm.rendezvous("on_epoch_end is done!")
        else:
            execute()