示例#1
0
    def __init__(
        self,
        params: _params_t,
        optim: Type[Optimizer] = SGD,
        group: Optional[Any] = None,
        broadcast_buffer_size: int = -1,
        broadcast_fp16: bool = False,
        **default: Any,
    ):
        """Register every parameter in the root ``param_groups`` and set up sharding.

        Args:
            params: parameters (or parameter groups) handed to the base
                optimizer constructor, so they all live in ``self.param_groups``.
            optim: optimizer class, stored as ``self._optim_constructor``.
            group: process group used for the topology queries below; ``None``
                selects ``dist.group.WORLD``.
            broadcast_buffer_size: not read by this constructor.
            broadcast_fp16: stored as ``self.broadcast_fp16``.
            **default: hyper-parameters forwarded to the base constructor as its
                defaults dict and also stored as ``self._optim_defaults``.

        The constructor finishes by calling ``self.refresh_trainable()``, which
        (per the original comment) partitions the parameters and builds the
        optimizer for this rank's shard.
        """
        # Every model parameter is registered in the root .param_groups.
        # The flag is raised only for the duration of the base constructor;
        # NOTE(review): presumably an overridden hook (e.g. add_param_group)
        # reads it to skip re-partitioning while the base class is still
        # initializing -- that hook is not visible here.
        self.in_super_constructor = True
        super().__init__(params, default)
        self.in_super_constructor = False

        # Lazily-filled partition bookkeeping (populated on first request).
        # device -> per-rank lists of parameters
        self.__per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()
        self.__param_rank: Dict[torch.Tensor, int] = {}
        self._partition_parameters: List[List[dict]] = []
        self.__param_to_index: Dict[int, int] = {}
        self.__local_params: Optional[List[torch.Tensor]] = None

        # Construction arguments kept as-is for later use.
        self._optim_defaults = default
        self._optim_constructor = optim

        # Communication topology of the selected process group.
        process_group = dist.group.WORLD if group is None else group
        self.group = process_group
        self.world_size = dist.get_world_size(process_group)
        self.backend = dist.get_backend(process_group)
        self.rank = dist.get_rank(process_group)
        self.global_rank = get_global_rank(process_group, self.rank)
        # Index i holds the global rank of local rank i within the group.
        self._local_to_global_rank = [
            get_global_rank(process_group, local_rank)
            for local_rank in range(self.world_size)
        ]

        self.broadcast_fp16 = broadcast_fp16
        self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
        # Optional consolidated optimizer state; starts empty.
        self._all_states: List[Dict[str, Any]] = []
        self._default_device = torch.device("cpu")

        # Everything tied to the trainable parameters (partition + shard
        # optimizer) is (re)built here; it must run after all fields above exist.
        self.refresh_trainable()
示例#2
0
    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
        reduce_buffer_size: int = 2**23,
        auto_refresh_trainable: bool = True,
        reduce_fp16: bool = False,
    ):
        """Wrap ``module`` for data-parallel training with sharded optimizer state.

        Args:
            module: the model to wrap; exposed unchanged as ``self.module``.
            sharded_optimizer: one ``OSS`` instance or a list of them; a single
                instance is normalized to a one-element list.
            process_group: communication group; ``None`` selects
                ``dist.group.WORLD``.
            broadcast_buffers: stored as ``self._enable_broadcast_buffers``.
            sync_models_at_startup: when True, ``self._sync_params_and_buffers()``
                is called so that all ranks start from the same model.
            reduce_buffer_size: requested bucket size, counted in tensor
                elements. It is capped to the model size (total ``numel`` of the
                module parameters) and forced to 0 when the group has a single
                rank; buckets are used only if the resulting size is > 0.
            auto_refresh_trainable: stored as ``self._auto_refresh_trainable``.
            reduce_fp16: request fp16 gradient reduction. It is switched off
                (with a warning) whenever ``reduce_buffer_size > 0``.

        Raises:
            AssertionError: if the module parameters are not all on a single
                device type (this includes a module with no parameters).
        """
        super().__init__()

        # This field needs to be exposed to ensure interface parity with DDP
        self.module = module

        self._sharded_optimizers = (
            sharded_optimizer if isinstance(sharded_optimizer, list) else [sharded_optimizer]
        )
        self._enable_broadcast_buffers = broadcast_buffers
        self._auto_refresh_trainable = auto_refresh_trainable
        self._reduce_fp16 = reduce_fp16
        if reduce_buffer_size > 0 and reduce_fp16:
            self._reduce_fp16 = False
            logging.warning(
                "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
            )

        # Handle a no_sync() context which prevents the gradient synchronization,
        # accumulate in place
        self._should_accumulate_grads = False
        self._accumulate_grads_flipped = False

        # Communication related attributes
        self._process_group = process_group if process_group is not None else dist.group.WORLD
        self._backend = dist.get_backend(self._process_group)
        # Queried once and reused below.
        world_size = dist.get_world_size(self._process_group)
        self._world_size_scaling = 1.0 / world_size  # > 0
        self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
        self._rank = dist.get_rank(self._process_group)
        self._global_rank = get_global_rank(self._process_group, self._rank)
        # Index i holds the global rank of local rank i within the group.
        self._local_to_global_rank = [
            get_global_rank(self._process_group, i) for i in range(world_size)
        ]

        # Expose some of the PytorchDDP attributes, some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present, this is not handled
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1

        distinct_device_types = {device.type for device in devices}
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on "
            "the same type of devices, but input module parameters are located on {} different device types: {}."
        ).format(len(distinct_device_types), distinct_device_types)
        self.device_type = next(iter(distinct_device_types))

        # Scaffolding to be able to reduce the grads during the BW pass
        # several optimizers can be present each working on separate parameter set which is spread across multiple ranks

        # - we build a flat list of all the parameters involved globally, ordered
        #   optimizer -> device -> rank -> parameter. extend() keeps this linear;
        #   the previous sum(list_of_lists, []) was quadratic in the parameter count.
        self._all_params: List[torch.Tensor] = []
        for optim in self._sharded_optimizers:
            for per_rank_params in optim._per_device_params.values():
                for rank_params in per_rank_params:
                    self._all_params.extend(rank_params)
        self._trainable_params: List[torch.Tensor] = []
        self._grad_to_be_reduced: List[bool] = []
        self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
        self._reference_trainable_mask = list(map(_trainable, self._all_params))

        # - setup buckets and tensor views
        model_size = sum(p.numel() for p in self.module.parameters())
        self._buffer_max_size = min(reduce_buffer_size, model_size)

        if world_size == 1:
            self._buffer_max_size = 0
            logging.info(
                "Training is not really distributed, single rank. Deactivating buckets"
            )

        logging.info(
            "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters"
            .format(self._buffer_max_size / 2**20, model_size / 2**20))
        self._use_buckets = self._buffer_max_size > 0

        self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
        self._should_bucket_grad: List[bool] = []
        self._bucket_list: List[GradBucket] = []

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._grad_hooks: List[Any] = []
        self._manual_reduce: List[Callable] = []

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()

        self._work_handles: Deque[Workhandle] = deque()
        self._bucket_flush_callback_set = False