import pytest
import torch

from fairscale.nn.misc import GradBucket  # assumed import path for GradBucket


def test_max_size():
    with torch.no_grad():  # remove a warning
        param = torch.rand((20, 30), requires_grad=True)
        param.grad = torch.rand(20, 30)

        bucket = GradBucket(5, param.dtype, param.device, -1)
        with pytest.raises(AssertionError):
            bucket.add_grad(param)

def test_grad_values_conserved():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(10, param.dtype, param.device, -1)
        param_ = param.clone()

        bucket.add_grad(param_)
        assert torch.allclose(param.grad, param_.grad)

def test_memory_leak():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        storage = bucket.buffer.storage()

        # See https://github.com/pytorch/pytorch/pull/59671/
        if hasattr(storage, "nbytes"):
            assert storage.nbytes() == 6 * bucket.buffer.element_size()
        else:
            assert len(storage) == 6

def _setup_bucket_strategy(self) -> None:
    """Devise a bucketing strategy on a per-rank ownership level.
    These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
    This method can be slow for big models, but it is not typically called often (not for every forward pass, for instance).
    """
    with profiler.record_function("fairscale::sdp::setup_buckets"):
        if not self._use_buckets:
            return

        # Devise the bucketing strategy. Parameters are already sorted, in that:
        # - these are only the trainable parameters, so they should produce grads
        # - they are sorted by increasing size
        self._buckets = {}
        self._should_bucket_grad = [False for _ in self._trainable_params]

        for i, param in enumerate(self._trainable_params):
            device = param.device
            dst_rank = self._trainable_param_to_rank[param]

            if param.device not in self._buckets.keys():
                self._buckets[param.device] = {}

            if dst_rank not in self._buckets[param.device].keys():
                self._buckets[param.device][dst_rank] = GradBucket(
                    self._buffer_max_size,
                    dtype=param.dtype,
                    device=param.device,
                    destination=self._local_to_global_rank[dst_rank],
                )

            # Criteria to decide whether this parameter is to be bucketed or not:
            # - enough room in the bucket
            if self._buckets[device][dst_rank].can_add_grad_view(param):
                self._buckets[device][dst_rank].add_grad(param)
                self._should_bucket_grad[i] = True

        self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

        # Resize the buckets to remove lost space in the end
        for bucket in self._bucket_list:
            bucket.shrink()

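# Illustration only, not part of the library: a minimal sketch of the bucketing flow that
# _setup_bucket_strategy performs for a single (device, destination rank) pair, using only
# the GradBucket calls seen above. The buffer size (64) and the destination rank (-1) are
# arbitrary placeholder values, and the helper name below is hypothetical.
def _example_single_rank_bucketing():
    with torch.no_grad():  # remove a warning when assigning .grad manually
        params = [torch.rand(2, 3, requires_grad=True) for _ in range(3)]
        for p in params:
            p.grad = torch.rand(2, 3)

        # One bucket per (device, destination rank) pair
        bucket = GradBucket(64, dtype=params[0].dtype, device=params[0].device, destination=-1)

        for p in params:
            # Only bucket this grad if there is enough room left in the flat buffer
            if bucket.can_add_grad_view(p):
                bucket.add_grad(p)  # p.grad should now be a view into bucket.buffer

        # Drop the unused tail of the flat buffer, as done for every bucket at the end
        bucket.shrink()
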
def test_collapse():
    with torch.no_grad():  # remove a warning
        size = (5, 6)
        param = torch.rand(size, requires_grad=True)
        param.grad = torch.rand(size)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()
        bucket.collapse()

        assert bucket.buffer.numel() == 0
        assert param.grad is None

        bucket.rebuild()
        assert param.grad is not None
        assert torch.allclose(param.grad, torch.zeros(size))