Example #1
import pytest
import torch

# Assumed import path for fairscale's gradient-bucketing helper; adjust to your
# fairscale version if needed. The other examples on this page assume the same imports.
from fairscale.nn.misc import GradBucket


def test_max_size():
    with torch.no_grad():  # remove a warning
        param = torch.rand((20, 30), requires_grad=True)
        param.grad = torch.rand(20, 30)

        bucket = GradBucket(5, param.dtype, param.device, -1)
        with pytest.raises(AssertionError):
            bucket.add_grad(param)
Example #2
def test_memory_leak():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        assert len(bucket.buffer.storage()) == 6
Example #3
def test_grad_values_conserved():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(10, param.dtype, param.device, -1)
        param_ = param.clone()
        param_.grad = param.grad.clone()  # clone() does not copy .grad, so copy it explicitly

        bucket.add_grad(param_)
        # Moving the grad into the bucket buffer must preserve its values.
        assert torch.allclose(param.grad, param_.grad)
Example #4
def test_memory_leak():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        storage = bucket.buffer.storage()
        # See https://github.com/pytorch/pytorch/pull/59671/
        if hasattr(storage, "nbytes"):
            assert storage.nbytes() == 6 * bucket.buffer.element_size()
        else:
            assert len(storage) == 6
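
The branch above works around a change in how PyTorch reports storage sizes (see the linked PR). The same check can be expressed with a small version-agnostic helper; this is a hypothetical convenience function for illustration, not part of fairscale or PyTorch:

def storage_nbytes(tensor):
    """Size of ``tensor``'s underlying storage in bytes, across PyTorch versions."""
    storage = tensor.storage()
    if hasattr(storage, "nbytes"):
        # Newer PyTorch: the storage reports its size in bytes directly.
        return storage.nbytes()
    # Older PyTorch: len(storage) is an element count, so scale by the element size.
    return len(storage) * storage.element_size()

# The assertion above then becomes:
# assert storage_nbytes(bucket.buffer) == 6 * bucket.buffer.element_size()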
Example #5
    def _setup_bucket_strategy(self) -> None:
        """Devise a bucketing strategy on a per-rank ownership level.
        These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
        This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
        """

        with profiler.record_function("fairscale::sdp::setup_buckets"):
            if not self._use_buckets:
                return

            # Devise the bucketing strategy. Parameters are already sorted, in that:
            # - these are only the trainable parameters, so they should produce grads
            # - they are sorted by increasing size
            self._buckets = {}
            self._should_bucket_grad = [False for _ in self._trainable_params]

            for i, param in enumerate(self._trainable_params):
                device = param.device
                dst_rank = self._trainable_param_to_rank[param]

                if param.device not in self._buckets.keys():
                    self._buckets[param.device] = {}

                if dst_rank not in self._buckets[param.device].keys():
                    self._buckets[param.device][dst_rank] = GradBucket(
                        self._buffer_max_size,
                        dtype=param.dtype,
                        device=param.device,
                        destination=self._local_to_global_rank[dst_rank],
                    )

                # Criteria to decide whether this parameter is to be bucketed or not:
                # - enough room in the bucket
                if self._buckets[device][dst_rank].can_add_grad_view(param):
                    self._buckets[device][dst_rank].add_grad(param)
                    self._should_bucket_grad[i] = True

            self._bucket_list = list(
                chain(*[
                    self._buckets[device].values()
                    for device in self._buckets.keys()
                ]))

            # Resize the buckets to remove unused space at the end
            for bucket in self._bucket_list:
                bucket.shrink()
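
Stripped of the per-device and per-rank bookkeeping, the loop above is just the GradBucket lifecycle used throughout these examples. Below is a minimal sketch assuming a single device, a single destination rank, and the import path from Example #1; the function name and default max_size are made up for illustration:

from fairscale.nn.misc import GradBucket  # assumed import path, as in Example #1


def build_single_bucket(trainable_params, max_size=2 ** 16):
    # Pretend every trainable parameter lives on one device and is owned by one rank.
    first = trainable_params[0]
    bucket = GradBucket(max_size, dtype=first.dtype, device=first.device, destination=-1)

    bucketed = []
    for p in trainable_params:
        # Only bucket a gradient if there is still room in the buffer.
        if bucket.can_add_grad_view(p):
            bucket.add_grad(p)
            bucketed.append(p)

    # Trim the buffer so no space is wasted past the last registered gradient.
    bucket.shrink()
    return bucket, bucketed

With two (2, 3) parameters this should leave a 12-element buffer after shrink(), in line with what test_memory_leak checks for a single parameter.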
Example #6
def test_collapse():
    with torch.no_grad():  # remove a warning
        size = (5, 6)
        param = torch.rand(size, requires_grad=True)
        param.grad = torch.rand(size)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()
        bucket.collapse()

        assert bucket.buffer.numel() == 0
        assert param.grad is None
        bucket.rebuild()

        assert param.grad is not None
        # After rebuild the grad is re-allocated and zero-initialized.
        assert torch.allclose(param.grad, torch.zeros(size))