Example #1
import torch
import torch_xla.core.xla_model as xm


def xm_save(data, file_or_path, master_only=True, global_master=False, rendezvous=True):
    """Saves the input data into a file.

    The saved data is transferred to the PyTorch CPU device before being saved, so a
    following `torch.load()` will load CPU data.
    Care must be taken when working with views. Instead of saving views, it's
    recommended that you recreate them after the tensors have been loaded and
    moved to their destination device(s).

    Args:
      data: The input data to be saved. Any nested combination of Python objects
        (lists, tuples, sets, dicts, ...).
      file_or_path: The destination for the data saving operation. Either a file
        path or a Python file object. If `master_only` is ``False``, the paths or
        file objects must point to different destinations, as otherwise all the
        writes from the same host will overwrite each other.
      master_only (bool, optional): Whether only the master device should save the
        data. If ``False``, the `file_or_path` argument should be a different file or
        path for each of the ordinals taking part in the replication; otherwise
        all the replicas on the same host will be writing to the same location.
        Default: True
      global_master (bool, optional): When ``master_only`` is ``True``, this flag
        controls whether every host's master (if ``global_master`` is ``False``)
        saves the content, or only the global master (ordinal 0).
        Default: False
      rendezvous (bool, optional): Whether to perform a rendezvous after saving, so
        that all replicas synchronize before continuing.
        Default: True
    """
    should_write_data = not master_only or xm.is_master_ordinal(
        local=not global_master)

    cpu_data = xm._maybe_convert_to_cpu(data, convert=should_write_data)
    if should_write_data:
        torch.save(cpu_data, file_or_path)
    if rendezvous:
        # Every replica reaches this barrier, so none races ahead of the write.
        xm.rendezvous('torch_xla.core.xla_model.save')
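A minimal usage sketch of the function above, assuming a standard torch_xla multiprocessing setup; the `_mp_fn` entry point, the model, and the output path are illustrative names, not part of the source:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    model = torch.nn.Linear(4, 2).to(device)
    # All replicas call xm_save and meet at its rendezvous, but only the
    # per-host master (or the global master, if global_master=True) writes.
    xm_save(model.state_dict(), "/tmp/model.pt", master_only=True)


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())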
Example #2
import torch
import torch_xla.core.xla_model as xm

from mmf.utils.distributed import is_main


def save_xla_ckpt(ckpt, file_or_path):
    """
    Similar to xm.save, but only try to convert "model" and "optimizer" in an MMF
    checkpoint to CPU, since they hold PyTorch tensors. Other items like lr_scheduler
    often cannot be saved with xm.save due to its errors in handling mappingproxy.

    Only save on the global main process (which is different from the default behavior
    of xm.save that saves a checkpoint on each node).
    """
    should_write_data = is_main()

    is_full_ckpt = isinstance(ckpt, dict) and "model" in ckpt and "optimizer" in ckpt
    if is_full_ckpt:
        ckpt["model"] = xm._maybe_convert_to_cpu(ckpt["model"],
                                                 convert=should_write_data)
        ckpt["optimizer"] = xm._maybe_convert_to_cpu(ckpt["optimizer"],
                                                     convert=should_write_data)
    else:
        ckpt = xm._maybe_convert_to_cpu(ckpt, convert=should_write_data)

    if should_write_data:
        torch.save(ckpt, file_or_path)
    xm.rendezvous("mmf.utils.checkpoint.save_xla_ckpt")
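A short usage sketch, assuming training already runs on an XLA device; the model, optimizer, and checkpoint path are illustrative, not taken from the source:

import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(4, 2).to(xm.xla_device())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ckpt = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
# Only the global main process writes; all processes still hit the rendezvous.
save_xla_ckpt(ckpt, "/tmp/mmf_checkpoint.pt")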
Example #3
File: utils.py Project: stjordanis/fairseq
def xla_device_to_cpu(dat):
    import torch_xla.core.xla_model as xm

    # Recursively copies any XLA tensors nested in `dat` to CPU,
    # leaving non-tensor objects untouched.
    return xm._maybe_convert_to_cpu(dat)
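A brief usage sketch (the model and file name are illustrative): once the tensors are on CPU, the resulting file can be loaded later without torch_xla installed.

import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(4, 2).to(xm.xla_device())
cpu_state = xla_device_to_cpu(model.state_dict())
torch.save(cpu_state, "checkpoint_cpu.pt")  # plain CPU checkpoint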