Example #1
    def _broadcast_state_dict(self) -> None:
        """Broadcast this rank's state shard, discard others"""

        # Default to CPU space to gain some memory headroom
        local_cpu_state = recursive_copy_to_device(self.local_state_dict(),
                                                   non_blocking=True,
                                                   device=torch.device("cpu"))

        for rank in range(self.world_size):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded optimizer state to the reference replica from rank %s",
                    rank,
                )
                dist.broadcast_object_list([local_cpu_state],
                                           src=self.global_rank,
                                           group=self.group)
            else:
                global_rank = self.get_global_rank(self.group, rank)

                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
                dist.broadcast_object_list([0],
                                           src=global_rank,
                                           group=self.group)
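A minimal, stand-alone sketch of the round-robin pattern above (a hedged restatement, not fairscale code; it assumes torch.distributed is already initialized and uses the default process group, where group ranks equal global ranks):

import torch.distributed as dist

def broadcast_own_shard(local_state, rank, world_size, group=None):
    # Each rank takes a turn as the source; receivers pass a placeholder list
    # that broadcast_object_list overwrites in place.
    for src in range(world_size):
        if src == rank:
            dist.broadcast_object_list([local_state], src=src, group=group)
        else:
            placeholder = [None]
            dist.broadcast_object_list(placeholder, src=src, group=group)
            # placeholder[0] now holds rank `src`'s shard; discard it, as above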
Example #2
def _broadcast_processed_optim_state_dict(
    processed_optim_state_dict: Optional[Dict[str, Any]],
    rank: int,
    group,
    device: torch.device,
) -> Dict[str, Any]:
    """
    Broadcasts the processed optimizer state dict from rank 0 to all ranks.

    Args:
        processed_optim_state_dict (Optional[Dict[str, Any]]): The flattened
            optimizer state dict with positive-dimension tensor states replaced
            with metadata if on rank 0; ignored otherwise.
        device (torch.device): Device to move zero-dimension tensors post-
            broadcast.

    Returns:
        Dict[str, Any]: The processed optimizer state dict.
    """
    # Broadcast the processed optimizer state dict from rank 0 to all ranks
    obj_list = [processed_optim_state_dict] if rank == 0 \
        else [None]
    dist.broadcast_object_list(obj_list, src=0, group=group)
    processed_optim_state_dict = obj_list[0]  # type: ignore[assignment]
    assert processed_optim_state_dict is not None
    # Move zero-dimension tensors to `device`
    for param_state in processed_optim_state_dict["state"].values():
        for state_name, value in param_state.items():
            if _is_zero_dim_tensor(value):
                param_state[state_name] = value.to(device)
    return processed_optim_state_dict
Example #3
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])
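The test above relies on PyTorch's distributed test harness; a hedged, stand-alone equivalent (world size, port, and init method are arbitrary choices) could look like this:

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    val = 99 if rank == 0 else None
    object_list = [val] * world_size
    dist.broadcast_object_list(object_list)  # src defaults to 0
    assert all(item == 99 for item in object_list)  # every slot now holds rank 0's value
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)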
Example #4
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """Collect all the state shards, in CPU memory."""
        all_states = []

        for rank in range(self.world_size):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(),
                                             non_blocking=True,
                                             device=torch.device("cpu")))

                # Sync with other replicas
                dist.broadcast_object_list([0],
                                           src=self.global_rank,
                                           group=self.group)
            else:
                # Fetch the optim state from the other replicas
                global_rank = self.get_global_rank(self.group, rank)

                replica_state = [0]
                dist.broadcast_object_list(replica_state,
                                           src=global_rank,
                                           group=self.group)

                all_states.append(
                    recursive_copy_to_device(replica_state[0],
                                             non_blocking=True,
                                             device=torch.device("cpu")))

                logging.debug("State from rank %s received", rank)

        return all_states
Example #5
def sync_object_ranks(something_to_sync: Any, reference_rank: int,
                      device: torch.device) -> Any:
    package = [something_to_sync]
    dist.broadcast_object_list(package,
                               src=reference_rank,
                               group=dist.group.WORLD)
    package_sync = package[0]
    return package_sync
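A hedged usage sketch for the helper above; note that `device` is unused on this torch-native path, and `model`, `rank`, and `reference_rank` below are illustrative names:

# Every rank ends up with the object that lives on reference_rank.
state = sync_object_ranks(model.state_dict() if rank == reference_rank else None,
                          reference_rank, device)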
Example #6
 def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
     """Broadcasts the full optimizer state dict in place of using
     ``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
     obj_list = [full_osd]
     dist.broadcast_object_list(
         obj_list, src=0, group=group,
     )
     full_osd = obj_list[0]
     return full_osd
Example #7
File: oss.py Project: Jawad1347/fairscale
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """Collect all the state shards, in CPU memory."""
        all_states = []

        for rank in range(self.world_size):
            if rank == self.rank:
                logging.debug("Saving self state")
                all_states.append(
                    recursive_copy_to_device(self.local_state_dict(),
                                             non_blocking=True,
                                             device=torch.device("cpu")))

                # Sync with other replicas
                if _torch_broadcast_object:
                    # torch native object broadcast
                    dist.broadcast_object_list([0],
                                               src=self.global_rank,
                                               group=self.group)
                else:
                    # legacy compatibility for old torch versions
                    broadcast_object(
                        torch.tensor([0],
                                     dtype=torch.uint8,
                                     device=self._device),
                        src_rank=self.global_rank,
                        group=self.group,
                        dist_device=self._device,
                    )
            else:
                # Fetch the optim state from the other replicas
                global_rank = self.get_global_rank(self.group, rank)

                if _torch_broadcast_object:
                    replica_state_l = [0]
                    dist.broadcast_object_list(replica_state_l,
                                               src=global_rank,
                                               group=self.group)
                    replica_state = replica_state_l[0]
                else:
                    replica_state = broadcast_object(
                        torch.tensor([0],
                                     dtype=torch.uint8,
                                     device=self._device),
                        src_rank=global_rank,
                        group=self.group,
                        dist_device=self._device,
                    )

                all_states.append(
                    recursive_copy_to_device(replica_state,
                                             non_blocking=True,
                                             device=torch.device("cpu")))

                logging.debug("State from rank %s received", rank)

        return all_states
Example #8
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
    dist_init(rank, world_size, tempfile_name)
    device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 3, 3, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                torch.nn.Linear(hidden, target_width))
    model.to(device)

    loss_fn = torch.nn.L1Loss()
    loss_fn.to(device)

    # With SGD, Momentum is required to get a state to shard
    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if rank == reference_rank:
        optimizer_state_dict = optimizer.state_dict()
        assert len(optimizer_state_dict["state"]) == world_size
    else:
        optimizer_state_dict = {}

    optim_state = [optimizer_state_dict]
    if _torch_broadcast_object:
        dist.broadcast_object_list(optim_state,
                                   src=reference_rank,
                                   group=dist.group.WORLD)
        optimizer_state_dict = optim_state[0]
    else:
        optimizer_state_dict = optim.utils.broadcast_object(
            optimizer_state_dict,
            src_rank=reference_rank,
            group=dist.group.WORLD,
            dist_device=device)

    # Load the optimizer state dict
    optimizer.load_state_dict(optimizer_state_dict)
    dist.destroy_process_group()
Example #9
 def test_step(self, split="val"):
     if self.device == 0:
         metric, loss = self._test_step()
         val_loss = float(loss[split])
         val_metric = float(metric[split])
         object_list = [val_metric, val_loss]
     else:
         object_list = [None, None]
     dist.broadcast_object_list(object_list, src=0)
     return object_list[0], object_list[1]
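A hedged variant of the metric handoff above: packing the values into a dict keeps the call site readable as more metrics are added (dist is assumed initialized, and val_metric / val_loss are assumed to have been computed on rank 0):

if dist.get_rank() == 0:
    payload = [{"val_metric": val_metric, "val_loss": val_loss}]
else:
    payload = [None]
dist.broadcast_object_list(payload, src=0)
metrics = payload[0]  # same dict on every rank afterwards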
Example #10
File: utils.py Project: huaxz1986/pytorch
 def broadcast_object(self, object: Optional[T]) -> T:
     """
     Same as c10d::broadcast_object_list but works without distributed enabled.
     """
     object_list = [object]
     if self.use_dist:
         dist.broadcast_object_list(object_list=object_list,
                                    group=self.group,
                                    src=self.coordinator_rank)
     return cast(T, object_list[0])
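A hedged usage sketch of the wrapper above: because it degrades to a pass-through when `use_dist` is False, the same call site works in single-process runs (`coordinator`, `build_plan`, and `is_coordinator` are illustrative names):

plan = coordinator.broadcast_object(build_plan() if is_coordinator else None)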
Example #11
    def gen_metadata(self) -> Metadata:
        module = TestModule()
        # compute the default saved metadata (must pass include_non_replicated_tensors or we'll get incomplete MD)
        metadata, _, _ = _prepare(module.state_dict(), True)

        # _prepare only produces local metadata; broadcast rank 0's copy so every rank agrees
        metadata = [metadata]
        dist.broadcast_object_list(metadata)

        return metadata[0]
Example #12
def sync_object_ranks(something_to_sync: Any, reference_rank: int,
                      device: torch.device) -> Any:
    if _torch_broadcast_object:
        package = [something_to_sync]
        dist.broadcast_object_list(package,
                                   src=reference_rank,
                                   group=dist.group.WORLD)
        package_sync = package[0]
    else:
        package_sync = optim.utils.broadcast_object(something_to_sync,
                                                    src_rank=reference_rank,
                                                    group=dist.group.WORLD,
                                                    dist_device=device)

    return package_sync
Example #13
    def test_read_write_shard_tensor(self) -> None:
        paths = [tempfile.mkdtemp()]
        dist.broadcast_object_list(paths)

        path = paths[0]

        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        )

        model_to_save = MyShardedModel1(spec, init_rrefs=False)

        # Test save
        model_to_save._register_state_dict_hook(state_dict_hook)
        state_dict_to_save = model_to_save.state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save,
                        storage_writer=fs_writer)

        dist.barrier()

        # Create a new model
        model_to_load = MyShardedModel1(spec, init_rrefs=False)
        # This is not the correct hook for loading the state dict
        # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        model_to_load._register_state_dict_hook(state_dict_hook)
        state_dict_to_load_to = model_to_load.state_dict()

        dist.barrier()

        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to,
                                    state_dict_to_save)

        # Test load.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to,
                        storage_reader=fs_reader)

        assert_state_dict_equal(self, state_dict_to_load_to,
                                state_dict_to_save)
        dist.barrier()
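A hedged sketch of the temp-directory handshake at the top of the test: letting only rank 0 create the directory avoids the stray directories the other ranks would otherwise leave behind (a process group is assumed to be initialized):

import tempfile
import torch.distributed as dist

paths = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
dist.broadcast_object_list(paths, src=0)
path = paths[0]  # every rank now reads and writes in the same location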
Example #14
File: oss.py Project: LuisFalva/fairscale
    def _broadcast_state_dict(self) -> None:
        """Broadcast this rank's state shard, discard others"""

        # Default to CPU space to gain some memory headroom
        local_cpu_state = recursive_copy_to_device(self.local_state_dict(),
                                                   non_blocking=True,
                                                   device=torch.device("cpu"))

        # Tensor cannot be really empty, even if its size is meaningless
        dummy_sync_tensor = torch.tensor([1], device=self._device)

        for rank in range(self.world_size):
            if rank == self.rank:
                # Send the state to the reference replica
                logging.debug(
                    "Sending the sharded optimizer state to the reference replica from rank %s",
                    rank,
                )
                if _torch_broadcast_object:
                    # torch native object broadcast
                    dist.broadcast_object_list([local_cpu_state],
                                               src=self.global_rank,
                                               group=self.group)
                else:
                    # legacy compatibility for old torch versions
                    broadcast_object(self.local_state_dict(),
                                     src_rank=self.global_rank,
                                     group=self.group,
                                     dist_device=self._device)
            else:
                global_rank = self.get_global_rank(self.group, rank)

                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
                if _torch_broadcast_object:
                    dist.broadcast_object_list([dummy_sync_tensor],
                                               src=global_rank,
                                               group=self.group)
                else:
                    broadcast_object(
                        torch.tensor([dummy_sync_tensor],
                                     dtype=torch.uint8,
                                     device=self._device),
                        src_rank=global_rank,
                        group=self.group,
                        dist_device=self._device,
                    )
Example #15
File: trainer.py Project: rpatil524/cogdl
    def distributed_test(self, model_w: ModelWrapper, loader, rank, fn):
        model_w.eval()
        # if rank == 0:
        if dist.get_rank() == 0:
            if self.cpu_inference:
                model_w.to("cpu")
                _device = "cpu"
            else:
                _device = rank
            with torch.no_grad():
                result = fn(model_w, loader, _device)
            model_w.to(rank)

            object_list = [result]
        else:
            object_list = [None]
        dist.broadcast_object_list(object_list, src=0)
        return object_list[0]
Example #16
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)
            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = model.state_dict()

            save_state_dict(state_dict, writer)

            with FSDP.state_dict_type(model,
                                      state_dict_type), FSDP.state_dict_type(
                                          new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        with FullyShardedDataParallel.summon_full_params(
                model), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)
Example #17
def _broadcast_processed_optim_state_dict(
    processed_optim_state_dict: Optional[Dict[str, Any]],
    fsdp_flat_param_ids: Optional[Set[int]],
    rank: int,
    group,
    device: torch.device,
) -> Tuple[Dict[str, Any], Set[int]]:
    """
    Broadcasts the processed optimizer state dict and the accompanying FSDP
    parameter IDs from rank 0 to all ranks.

    Args:
        processed_optim_state_dict (Optional[Dict[str, Any]]): The full
            optimizer state dict with positive-dimension tensor states replaced
            with metadata if on rank 0; ignored otherwise.
        fsdp_flat_param_ids (Optional[Set[int]]): Parameter IDs corresponding
            to FSDP parameters if on rank 0; ignored otherwise.
        device (torch.device): Device to move zero-dimension tensors post-
            broadcast.

    Returns:
        Tuple[Dict[str, Any], Set[int]]: The processed optimizer state dict
        and the parameter IDs corresponding to FSDP parameters.
    """
    # Broadcast the two data structures from rank 0 to all ranks
    obj_list = [processed_optim_state_dict, fsdp_flat_param_ids] if rank == 0 \
        else [None, None]
    dist.broadcast_object_list(obj_list, src=0, group=group)
    processed_optim_state_dict, fsdp_flat_param_ids = obj_list  # type: ignore[assignment]
    assert processed_optim_state_dict is not None
    assert fsdp_flat_param_ids is not None
    # Move zero-dimension tensors to `device`
    for param_state in processed_optim_state_dict["state"].values():
        for state_name, value in param_state.items():
            if _is_zero_dim_tensor(value):
                param_state[state_name] = value.to(device)
    return processed_optim_state_dict, fsdp_flat_param_ids
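A hedged minimal restatement of the two-object broadcast above: several objects can ride in one `broadcast_object_list` call as long as every rank passes a list of the same length (the variable names mirror the function above and dist is assumed initialized):

payload = [osd_meta, flat_param_ids] if rank == 0 else [None, None]
dist.broadcast_object_list(payload, src=0, group=group)
osd_meta, flat_param_ids = payload  # identical on every rank afterwards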
Example #18
    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """Update the consolidated state_dict list, one per rank.

        .. warning: This needs to be called on all replicas"""

        # Sync lr and other attributes in case they have been updated
        self._update_param_groups()

        empty_messenger = torch.tensor([0],
                                       dtype=torch.uint8,
                                       device=self._device)

        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank

        # NOTE: In practice, `broadcast` is used, which is wasteful (gather would have been appropriate)
        # compatibility issues with some backends make the use of broadcast mandatory for now.
        # a possible follow up would be to move all sharded state management to RPC RRef

        self._all_states = []
        for rank in range(self.world_size):
            global_rank = _get_global_rank(self.group, rank)

            # This rank collects the whole state
            if self.rank == recipient_rank:
                if rank == self.rank:
                    self._all_states.append(
                        _recursive_copy_to_device(self.optim.state_dict(),
                                                  non_blocking=True,
                                                  device=torch.device("cpu")))
                else:
                    # Fetch the optim state from the other replicas
                    replica_state = [empty_messenger]
                    dist.broadcast_object_list(object_list=replica_state,
                                               src=global_rank,
                                               group=self.group)

                    self._all_states.append(
                        _recursive_copy_to_device(replica_state[0],
                                                  non_blocking=True,
                                                  device=torch.device("cpu")))
            else:
                # Acknowledge broadcasts, and send this rank's shard when needed
                # Default to CPU space to gain some memory headroom
                if rank == self.rank:
                    # Send the state to the reference replica
                    dist.broadcast_object_list(
                        object_list=[self.optim.state_dict()],
                        src=self.global_rank,
                        group=self.group)

                elif rank != recipient_rank:
                    # Discard this tensor/rank; broadcast is used for compatibility reasons
                    dist.broadcast_object_list(object_list=[empty_messenger],
                                               src=global_rank,
                                               group=self.group)
Example #19
def _broadcast_processed_optim_state_dict(
    processed_optim_state_dict: Optional[Dict[str, Any]],
    rank: int,
    group,
) -> Dict[str, Any]:
    """
    Broadcasts the processed optimizer state dict from rank 0 to all ranks.

    Args:
        processed_optim_state_dict (Optional[Dict[str, Any]]): The flattened
            optimizer state dict with positive-dimension tensor states replaced
            with metadata if on rank 0; ignored otherwise.

    Returns:
        Dict[str, Any]: The processed optimizer state dict.
    """
    # Broadcast the processed optimizer state dict from rank 0 to all ranks
    obj_list = [processed_optim_state_dict] if rank == 0 \
        else [None]
    dist.broadcast_object_list(obj_list, src=0, group=group)
    processed_optim_state_dict = obj_list[0]  # type: ignore[assignment]
    assert processed_optim_state_dict is not None
    # Keep zero-dimension tensors on CPU
    return processed_optim_state_dict
Example #20
 def _broadcast_list_from_main(self, val: Any) -> Any:
     if not self._is_ddp:
         return val
     dist.broadcast_object_list(val, src=0, group=self._gloo_handle)
     return val
Example #21
 def _broadcast_state_dict(self, state_dict):
     olist = [state_dict if self.rank == 0 else None]
     dist.broadcast_object_list(olist)
     return olist[0]
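A hedged generalization of the helper above that allows a non-zero source rank while keeping the default process group:

import torch.distributed as dist

def broadcast_state_dict(state_dict, src=0):
    # Only the source rank needs a real object; the rest pass a placeholder.
    olist = [state_dict if dist.get_rank() == src else None]
    dist.broadcast_object_list(olist, src=src)
    return olist[0]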
Example #22
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(),
                    optimizer_class=optimizer,
                    lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True,
                                        find_unused_parameters=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True,
                                find_unused_parameters=True)

                # The model should be synchronized in between the ranks at construction time, check that
                check_same_model_params(sharded_ddp_model, ddp_model,
                                        "Models differ from the start")

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params(sharded_ddp_model, ddp_model,
                                            "Models differ after a step")

                # The models should stay the same in between the ranks
                for i in range(BATCHS):
                    check_step()

                    # Change the models trainability, check that parity is maintained
                    # only check after a couple of constant batchs to go through both regimes
                    if i > BATCHS // 2:
                        next(ddp_model.parameters()).requires_grad = bool(i %
                                                                          2)
                        next(sharded_ddp_model.parameters()
                             ).requires_grad = bool(i % 2)

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(to=reference_rank)
                sharded_optim_state_dict = [
                    sharded_optimizer.state_dict()
                    if self.rank == reference_rank else {}
                ]
                dist.broadcast_object_list(sharded_optim_state_dict,
                                           src=reference_rank,
                                           group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(
                    ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(
                    sharded_optim_state_dict)  # mixup on purpose !
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                #  - self load, rewind, check no problem
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()
Example #23
import torch.distributed as dist

if dist.get_rank() == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# ruleid: pickles-in-torch-distributed
dist.broadcast_object_list(objects, src=0)

# ruleid: pickles-in-torch-distributed
dist.all_gather_object(output, gather_objects[dist.get_rank()])

# ruleid: pickles-in-torch-distributed
dist.gather_object(gather_objects[dist.get_rank()],
                   output if dist.get_rank() == 0 else None,
                   dst=0)

# ruleid: pickles-in-torch-distributed
dist.scatter_object_list(output_list, objects, src=0)
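The rule above flags these calls because the object collectives serialize arbitrary Python objects with pickle, which is unsafe with untrusted peers. For trusted, purely numeric payloads, a hedged sketch of the tensor-only alternative avoids pickling entirely (a process group is assumed to be initialized):

import torch
import torch.distributed as dist

flag = torch.tensor([1.0]) if dist.get_rank() == 0 else torch.zeros(1)
dist.broadcast(flag, src=0)  # raw tensor data, no pickle involved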
Example #24
def train(hyp, opt, device,
          callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
    callbacks.run('on_pretrain_routine_start')

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(
        colorstr('hyperparameters: ') + ', '.join(f'{k}={v}'
                                                  for k, v in hyp.items()))

    # Save run settings
    if not evolve:
        with open(save_dir / 'hyp.yaml', 'w') as f:
            yaml.safe_dump(hyp, f, sort_keys=False)
        with open(save_dir / 'opt.yaml', 'w') as f:
            yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Loggers
    data_dict = None
    if RANK in {-1, 0}:
        loggers = Loggers(save_dir, weights, opt, hyp,
                          LOGGER)  # loggers instance
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict
            if resume:
                weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

    # Config
    plots = not evolve and not opt.noplots  # create plots
    cuda = device.type != 'cpu'
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict['train'], data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if single_cls and len(
        data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(
        names
    ) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
    is_coco = isinstance(val_path, str) and val_path.endswith(
        'coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(
                weights)  # download if not found locally
        ckpt = torch.load(weights, map_location='cpu'
                          )  # load checkpoint to CPU to avoid CUDA memory leak
        model = Model(cfg or ckpt['model'].yaml,
                      ch=3,
                      nc=nc,
                      anchors=hyp.get('anchors')).to(device)  # create
        exclude = [
            'anchor'
        ] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict(
        )  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(),
                              exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(
            f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}'
        )  # report
    else:
        model = Model(cfg, ch=3, nc=nc,
                      anchors=hyp.get('anchors')).to(device)  # create
    amp = check_amp(model)  # check AMP

    # Freeze
    freeze = [
        f'model.{x}.'
        for x in (freeze if len(freeze) > 1 else range(freeze[0]))
    ]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs,
                           floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz, amp)
        loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'],
                                hyp['momentum'], hyp['weight_decay'])

    # Scheduler
    if opt.cos_lr:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    else:
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'
                                                                   ]  # linear
    scheduler = lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if resume:
            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
        if epochs < start_epoch:
            LOGGER.info(
                f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs."
            )
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            'WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
            'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.'
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,
        gs,
        single_cls,
        hyp=hyp,
        augment=True,
        cache=None if opt.cache == 'val' else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr('train: '),
        shuffle=True)
    mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
    nb = len(train_loader)  # number of batches
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in {-1, 0}:
        val_loader = create_dataloader(val_path,
                                       imgsz,
                                       batch_size // WORLD_SIZE * 2,
                                       gs,
                                       single_cls,
                                       hyp=hyp,
                                       cache=None if noval else opt.cache,
                                       rect=True,
                                       rank=-1,
                                       workers=workers * 2,
                                       pad=0.5,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            labels = np.concatenate(dataset.labels, 0)
            # c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, names, save_dir)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset,
                              model=model,
                              thr=hyp['anchor_t'],
                              imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

        callbacks.run('on_pretrain_routine_end')

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(
        model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640)**2 * 3 / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(
        dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb),
             100)  # number of warmup iterations, max(3 epochs, 100 iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0
               )  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    stopper, stop = EarlyStopping(patience=opt.patience), False
    compute_loss = ComputeLoss(model)  # init loss class
    callbacks.run('on_train_start')
    LOGGER.info(
        f'Image sizes {imgsz} train, {imgsz} val\n'
        f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
        f"Logging results to {colorstr('bold', save_dir)}\n"
        f'Starting training for {epochs} epochs...')
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        callbacks.run('on_train_epoch_start')
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (
                1 - maps)**2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels,
                                         nc=nc,
                                         class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n),
                                             weights=iw,
                                             k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        LOGGER.info(
            ('\n' + '%10s' * 7) %
            ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
        if RANK in {-1, 0}:
            pbar = tqdm(
                pbar, total=nb,
                bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
        optimizer.zero_grad()
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            callbacks.run('on_train_batch_start')
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float(
            ) / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 0 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs,
                                                     size=ns,
                                                     mode='bilinear',
                                                     align_corners=False)

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(
                    pred, targets.to(device))  # loss scaled by batch_size
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni - last_opt_step >= accumulate:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in {-1, 0}:
                mloss = (mloss * i + loss_items) / (i + 1
                                                    )  # update mean losses
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
                                     (f'{epoch}/{epochs - 1}', mem, *mloss,
                                      targets.shape[0], imgs.shape[-1]))
                callbacks.run('on_train_batch_end', ni, model, imgs, targets,
                              paths, plots)
                if callbacks.stop_training:
                    return
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in {-1, 0}:
            # mAP
            callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model,
                            include=[
                                'yaml', 'nc', 'hyp', 'names', 'stride',
                                'class_weights'
                            ])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = val.run(data_dict,
                                           batch_size=batch_size //
                                           WORLD_SIZE * 2,
                                           imgsz=imgsz,
                                           model=ema.ema,
                                           single_cls=single_cls,
                                           dataloader=val_loader,
                                           save_dir=save_dir,
                                           plots=False,
                                           callbacks=callbacks,
                                           compute_loss=compute_loss)

            # Update best mAP
            fi = fitness(np.array(results).reshape(
                1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness,
                          fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'ema': deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': optimizer.state_dict(),
                    'wandb_id':
                    loggers.wandb.wandb_run.id if loggers.wandb else None,
                    'date': datetime.now().isoformat()
                }

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if opt.save_period > 0 and epoch % opt.save_period == 0:
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                del ckpt
                callbacks.run('on_model_save', last, epoch, final_epoch,
                              best_fitness, fi)

        # EarlyStopping
        if RANK != -1:  # if DDP training
            broadcast_list = [stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list,
                                       0)  # broadcast 'stop' to all ranks
            if RANK != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in {-1, 0}:
        LOGGER.info(
            f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.'
        )
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device).half(),
                        iou_thres=0.65 if is_coco else
                        0.60,  # best pycocotools results at 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=plots,
                        callbacks=callbacks,
                        compute_loss=compute_loss)  # val best model with plots
                    if is_coco:
                        callbacks.run('on_fit_epoch_end',
                                      list(mloss) + list(results) + lr, epoch,
                                      best_fitness, fi)

        callbacks.run('on_train_end', last, best, plots, epoch, results)

    torch.cuda.empty_cache()
    return results
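The early-stopping handshake buried in the training loop above reduces to this pattern (a hedged, stand-alone restatement; `stop` is assumed to have been decided on rank 0 and dist to be initialized):

broadcast_list = [stop if dist.get_rank() == 0 else None]
dist.broadcast_object_list(broadcast_list, src=0)
stop = broadcast_list[0]  # every DDP rank now breaks out of the loop together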
Example #25
def train(model, dataset, args, rank, evaluator, loss_fn):
    print(f"Running on rank {rank}.")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(args.master_port)

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    model = copy.deepcopy(model).to(rank)
    model = DDP(model, device_ids=[rank])

    data = dataset[0]

    train_dataset, train_loader = get_train_loader(dataset, args, rank)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    epoch_iter = tqdm(range(args.max_epoch)) if rank == 0 else range(
        args.max_epoch)
    patience = 0
    max_score = 0
    min_loss = np.inf
    best_model = None
    for epoch in epoch_iter:
        train_dataset.shuffle()
        train_step(model, train_loader, optimizer, rank)
        if (epoch + 1) % args.eval_step == 0:
            if rank == 0:
                acc, loss = test_step(model.module, data, evaluator, loss_fn)
                train_acc = acc["train"]
                val_acc = acc["val"]
                val_loss = loss["val"]
                epoch_iter.set_description(
                    f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}"
                )
                model = model.to(rank)
                object_list = [val_loss, val_acc]
            else:
                object_list = [None, None]
            dist.broadcast_object_list(object_list, src=0)
            val_loss, val_acc = object_list

            if val_loss <= min_loss or val_acc >= max_score:
                if val_loss <= min_loss:
                    best_model = copy.deepcopy(model)
                min_loss = np.min((min_loss, val_loss))
                max_score = np.max((max_score, val_acc))
                patience = 0
            else:
                patience += 1
                if patience == args.patience:
                    break
        dist.barrier()

    if rank == 0:
        os.makedirs("./checkpoints", exist_ok=True)
        checkpoint_path = os.path.join("./checkpoints",
                                       f"{args.model}_{args.dataset}.pt")
        if best_model is not None:
            print(f"Saving model to {checkpoint_path}")
            torch.save(best_model.module.state_dict(), checkpoint_path)

    dist.barrier()

    dist.destroy_process_group()
Example #26
    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
        """Update the consolidated state_dict list, one per rank.

        Arguments:
            recipient_rank (int): on which rank to materialize the full state dict.
            -1 is a special value, which means that all ranks should have the state

        .. warning: This needs to be called on all replicas"""

        # Sync lr and other attributes in case they have been updated
        OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank
        logging.debug("Pulling the sharded optimizer state from all replicas")

        self._all_states = []
        should_collect_state = self.rank == recipient_rank or recipient_rank == -1
        should_send_state = self.rank != recipient_rank

        # NCCL requires CUDA tensors for all communication primitives
        dist_device = torch.device(
            "cuda"
        ) if self.backend == dist.Backend.NCCL else self._default_device

        for rank in range(self.world_size):
            if rank == self.rank:
                if should_collect_state:
                    logging.debug("Saving self state")
                    self._all_states.append(
                        recursive_copy_to_device(self.optim.state_dict(),
                                                 non_blocking=True,
                                                 device=torch.device("cpu")))

                # Sync with other replicas
                state_to_share = (
                    self.optim.state_dict() if should_send_state else
                    torch.tensor([0], dtype=torch.uint8, device=dist_device))
                if _gpu_capabilities_older_than_50():
                    _broadcast_object(state_to_share,
                                      src_rank=self.global_rank,
                                      group=self.group,
                                      dist_device=dist_device)
                else:
                    obj_list = [state_to_share]
                    dist.broadcast_object_list(
                        obj_list,
                        src=self.global_rank,
                        group=self.group,
                    )
            else:
                # Fetch the optim state from the other replicas
                if _gpu_capabilities_older_than_50():
                    replica_state = _broadcast_object(
                        torch.tensor([0],
                                     dtype=torch.uint8,
                                     device=dist_device),
                        src_rank=self._local_to_global_rank[rank],
                        group=self.group,
                        dist_device=dist_device,
                    )
                else:
                    obj_list = [
                        torch.tensor([0],
                                     dtype=torch.uint8,
                                     device=dist_device)
                    ]
                    dist.broadcast_object_list(
                        obj_list,
                        src=self._local_to_global_rank[rank],
                        group=self.group,
                    )
                    replica_state = obj_list[0]

                if should_collect_state:
                    self._all_states.append(
                        recursive_copy_to_device(replica_state,
                                                 non_blocking=True,
                                                 device=torch.device("cpu")))

                logging.debug("State from rank %s received", rank)
Example #27
def save_state_dict(
    state_dict: Dict[str, Any],
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False
) -> None:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` by having each rank only save their local shards.

    To produce a state_dict with ShardedTensor instances you must call
    ``_register_state_dict_hook`` on the top module with value
    `torch.distributed._shard.sharded_tensor.state_dict_hook` prior to
    calling `state_dict()` on the top module.

    There are no guarantees of backwards compatibility across PyTorch versions
    for saved state_dicts.

    If using the `process_group` argument, make sure that only its ranks
    call `save_state_dict` and that all data in state_dict belong to it.

    This function can also be used to save a state_dict without an initialized
    process group by passing ``no_dist=True``. The resulting checkpoint can
    later be consumed by `load_state_dict` in an SPMD fashion.

    Args:
        state_dict (Dict[str, Any]) : A state_dict
        storage_writer (StorageWriter): Instance of StorageWriter used to perform writes.
        process_group (ProcessGroup): ProcessGroup to be used for cross-rank synchronization.
        coordinator_rank (int): Rank used to coordinate the checkpoint; rank 0 is used by default.
        no_dist (bool): Don't attempt to save in SPMD style. Defaults to False.

    Example:
        >>> my_model = MyModule()
        >>> # We must call this function prior to state_dict()
        >>> my_model._register_state_dict_hook(state_dict_hook)

        >>> model_state_dict = my_model.state_dict()

        >>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed._shard.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note:: save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of objects
        must be moved to the GPU device before communication takes place. In this
        case, the device used is given by ``torch.cuda.current_device()`` and it
        is the user's responsibility to ensure that this is set so that each rank
        has an individual GPU, via ``torch.cuda.set_device()``
    """
    is_coordinator = no_dist or dist.get_rank(process_group) == coordinator_rank

    exceptions: List[Optional[BaseException]] = [None]
    if is_coordinator:
        try:
            storage_writer.prepare()
        except BaseException as e:
            exceptions = [e]

    # Writing can only start once prepare has finished
    if not no_dist:
        dist.broadcast_object_list(exceptions, group=process_group, src=coordinator_rank)

    if exceptions[0] is not None:
        raise CheckpointException("failed to prepare storage", {coordinator_rank : exceptions[0]})

    rank_write_error: Optional[BaseException]
    try:
        (
            metadata,
            bytes_write_requests,
            tensor_write_requests,
        ) = _prepare(state_dict, is_coordinator, process_group)

        combined_writes: List[Union[TensorWriteRequest, BytesWriteRequest]] = []
        combined_writes.extend(tensor_write_requests)
        combined_writes.extend(bytes_write_requests)

        storage_writer.prepare_storage(combined_writes)
        bytes_futures = storage_writer.write_bytes(bytes_write_requests)
        tensor_futures = storage_writer.write_tensors(tensor_write_requests)
        torch.futures.wait_all([bytes_futures, tensor_futures])
        rank_write_error = None
    except BaseException as e:
        rank_write_error = e

    all_errors: List[Optional[BaseException]]
    # collect all write errors
    if not no_dist:
        all_errors = [None] * dist.get_world_size(process_group)
        dist.gather_object(
            obj=rank_write_error,
            object_gather_list=all_errors if is_coordinator else None,
            dst=coordinator_rank
        )
    else:
        all_errors = [rank_write_error]

    result: List[Optional[CheckpointException]] = [None]
    if is_coordinator:
        message: Optional[str] = None
        # all_errors now holds one (possibly None) error per rank
        if any(all_errors):
            message = "Failed to write data"
        else:
            try:
                storage_writer.finish(metadata=metadata)
            except BaseException as e:
                all_errors[coordinator_rank] = e
                message = "Failed to finish checkpoint"

        if message is not None:
            node_failures = {i: err for i, err in enumerate(all_errors) if err is not None}
            result[0] = CheckpointException(message, node_failures)

    if not no_dist:
        dist.broadcast_object_list(
            result,
            group=process_group,
            src=coordinator_rank)

    if result[0] is not None:
        raise result[0]
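The note at the end of the docstring above (NCCL process groups need torch.cuda.set_device() before the collectives run) is easy to miss, so here is a hedged per-rank sketch of a save path. The save_checkpoint helper and its arguments are illustrative; the import locations simply follow the docstring, which also warns that these private APIs carry no backwards-compatibility guarantee.

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import state_dict_hook
from torch.distributed._shard.checkpoint import FileSystemWriter, save_state_dict


def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    # Assumes one process per GPU on a single node, so local rank == global rank
    torch.cuda.set_device(dist.get_rank())

    # Register the hook so ShardedTensor instances survive state_dict()
    model._register_state_dict_hook(state_dict_hook)
    state_dict = model.state_dict()

    save_state_dict(
        state_dict=state_dict,
        storage_writer=FileSystemWriter(path),
    )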
Example #28
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups,
                                          ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(),
                                        ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that
                check_same_model_params()

                # The models should stay the same across multiple steps, losses should stay the same
                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                for i in range(20):
                    check_step()

                # Test state dict save/load/equivalence with pytorch
                # - save state for both
                sharded_optimizer.consolidate_state_dict()
                sharded_optimizer_state_dict = (sharded_optimizer.state_dict()
                                                if self.rank == RECIPIENT_RANK
                                                else torch.zeros(1))
                ddp_state_dict = ddp_optimizer.state_dict()

                #  - sync the saved state with all the ranks
                exchange_list = [sharded_optimizer_state_dict]
                dist.broadcast_object_list(
                    exchange_list,
                    src=RECIPIENT_RANK,
                    group=dist.group.WORLD,
                )
                sharded_optimizer_state_dict = exchange_list[0]

                # - cross load the states
                ddp_optimizer.load_state_dict(sharded_optimizer_state_dict)
                sharded_optimizer.load_state_dict(ddp_state_dict)

                # - run one step, and check that the models are still the same
                check_step()
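Outside of a test, the same consolidate-then-broadcast dance is how a ZeroRedundancyOptimizer checkpoint is typically handled: gather the shards on one rank, persist them there, and broadcast the loaded dict back to everyone. The sketch below assumes a shared filesystem; CKPT_PATH and the helper names are illustrative.

import torch
import torch.distributed as dist

CKPT_PATH = "/tmp/zero_optim.pt"  # illustrative path on a shared filesystem


def save_zero_checkpoint(sharded_optimizer) -> None:
    # Gather every shard onto rank 0, then persist the full dict only there
    sharded_optimizer.consolidate_state_dict(recipient_rank=0)
    if dist.get_rank() == 0:
        torch.save(sharded_optimizer.state_dict(), CKPT_PATH)


def load_zero_checkpoint(sharded_optimizer) -> None:
    # Rank 0 reads the file; every other rank receives it through the broadcast
    obj_list = [torch.load(CKPT_PATH)] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(obj_list, src=0)
    sharded_optimizer.load_state_dict(obj_list[0])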
Example #29
    def get_file_path(self) -> str:
        paths = [tempfile.mkdtemp()] if dist.get_rank() == 0 else [None]
        dist.broadcast_object_list(paths)
        return paths[0]
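A brief usage note for the helper above: because the rank-0 temporary directory is broadcast to every rank, all ranks can write into the same location and then synchronize. The method below is an illustrative sketch assumed to live in the same test class (with os, torch, and torch.distributed as dist already imported in that module); the file name and payload are made up.

    def save_shared_shard(self) -> None:
        shared_dir = self.get_file_path()  # same rank-0-created directory on every rank
        shard_file = os.path.join(shared_dir, f"shard_{dist.get_rank()}.pt")
        torch.save({"rank": dist.get_rank()}, shard_file)
        dist.barrier()  # wait until every rank has written its shard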
Example #30
            def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
                ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that
                check_same_model_params()

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                # The models should stay identical across ranks over multiple steps
                for i in range(20):
                    check_step()

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(recipient_rank=reference_rank)
                sharded_optim_state_dict = [sharded_optimizer.state_dict() if self.rank == reference_rank else {}]
                dist.broadcast_object_list(sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                #  - self load, rewind, check no problem
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()