Example #1
import torch

# Assumed import path; ParamBucket ships with fairscale. The same imports
# (plus pytest where used) are assumed in the examples that follow.
from fairscale.nn.misc import ParamBucket


def test_param_values_conserved():
    param = torch.rand((2, 3))

    bucket = ParamBucket(10, param.dtype, param.device)
    param_ = param.clone()

    # add_param makes param_ a view of the bucket's flat buffer; the values must survive the move
    bucket.add_param(param_)
    assert torch.allclose(param, param_)
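The check above passes because `add_param` copies the tensor's values into the bucket's flat buffer and then rebinds the tensor as a view of that buffer. The sketch below shows one way to observe the aliasing; it assumes the flat tensor is exposed as `bucket.buffer`, which is not shown in these examples:

import torch

from fairscale.nn.misc import ParamBucket  # assumed import path


param = torch.rand((2, 3))
bucket = ParamBucket(10, param.dtype, param.device)
bucket.add_param(param)

# Assumption: the flat tensor is exposed as `bucket.buffer`
buf = bucket.buffer
start = buf.data_ptr()
end = start + buf.numel() * buf.element_size()

# After add_param, the tensor's storage should live inside the bucket's buffer
assert start <= param.data_ptr() < end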
Example #2
    def _setup_flat_buffers(self) -> None:
        """Make all params which are on the same device and tied to the same rank views of a single buffer.

        This is used at construction time, and any time parameter trainability changes (frozen or unfrozen) and
        `refresh_trainability` is called.
        """

        for device, per_rank_params in self._per_device_params.items():
            # Create the bucket store for this device if it does not exist yet
            # (this method can be called again, e.g. when trainability changes)
            if device not in self.buckets:
                self.buckets[device] = {}

            # Make the parameters views of the per-rank bucket
            for dst_rank, params in enumerate(per_rank_params):
                if len(params) > 0:

                    # Clone the non-trainable params: their storage may still be a view
                    # of a previous bucket which is about to be released
                    for param in params:
                        if not param.requires_grad:
                            param.data = param.data.detach().clone()

                    # Merge all the trainable params into a single bucket
                    trainable_params = [p for p in params if p.requires_grad]
                    if trainable_params:
                        buffer_size = sum(p.numel() for p in trainable_params)
                        bucket = ParamBucket(size=buffer_size,
                                             dtype=trainable_params[0].dtype,
                                             device=device)

                        for param in trainable_params:
                            bucket.add_param(param)

                        self.buckets[device][dst_rank] = bucket

        # Drop the bucket entries which are no longer in use (the set of devices may have changed)
        devices_in_use = list(self._per_device_params.keys())
        devices_to_pop = [d for d in self.buckets.keys() if d not in devices_in_use]
        for d in devices_to_pop:
            self.buckets.pop(d)
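The method above is where the trainable parameters of each (device, rank) pair get merged into one ParamBucket. Stripped of the optimizer bookkeeping, the core pattern is short; below is a sketch built around a hypothetical `flatten_trainable` helper, assuming the import path shown:

import torch

from fairscale.nn.misc import ParamBucket  # assumed import path


def flatten_trainable(params, device=torch.device("cpu")):
    """Pack the trainable tensors of `params` into a single ParamBucket (illustrative helper)."""
    trainable = [p for p in params if p.requires_grad]
    if not trainable:
        return None

    # Size the flat buffer to hold every trainable tensor, then register
    # each one so it becomes a view of that buffer
    bucket = ParamBucket(size=sum(p.numel() for p in trainable),
                         dtype=trainable[0].dtype,
                         device=device)
    for p in trainable:
        bucket.add_param(p)
    return bucket


params = [torch.rand(4, 4, requires_grad=True), torch.rand(8)]
bucket = flatten_trainable(params)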
Example #3
def test_double_check_int():
    param = torch.rand((5, 6))

    bucket = ParamBucket(300, param.dtype, param.device)
    bucket.add_param(param)

    # Re-adding the same parameter must be rejected
    with pytest.raises(AssertionError):
        bucket.add_param(param)
Example #4
def test_type_change():
    size = (5, 6)
    param = torch.rand(size, requires_grad=True)
    param_ = param.clone()

    bucket = ParamBucket(30, param.dtype, param.device)
    bucket.add_param(param)

    # Move the bucket to fp16 and back
    bucket.to(dtype=torch.float16, device=param.device)
    bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)

    # Same round trip for the reference tensor. Tensor.to() is not in-place,
    # so the result has to be reassigned.
    param_ = param_.to(dtype=torch.float16)
    param_ = param_.to(dtype=torch.float32)

    assert torch.allclose(param, param_)
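The reassignment of `param_` matters because `Tensor.to()` is not an in-place operation; calling it and dropping the result leaves the tensor untouched. A quick illustration in plain PyTorch:

import torch

t = torch.rand(4)

t.to(torch.float16)      # returns a new fp16 tensor, t itself is unchanged
assert t.dtype == torch.float32

t = t.to(torch.float16)  # reassigning is what actually switches the dtype
assert t.dtype == torch.float16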
Example #5
def test_max_size():
    param = torch.rand((20, 30))

    bucket = ParamBucket(5, param.dtype, param.device)
    # The tensor (600 elements) cannot fit in a 5-element bucket
    with pytest.raises(AssertionError):
        bucket.add_param(param)
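The assertion fires because a 20 x 30 tensor holds 600 elements while the bucket was created with room for only 5; sizing the bucket from the tensor avoids the problem. A minimal sketch, assuming the same import path as above:

import torch

from fairscale.nn.misc import ParamBucket  # assumed import path

param = torch.rand((20, 30))

# Size the bucket from the tensor so the capacity check cannot fail
bucket = ParamBucket(param.numel(), param.dtype, param.device)
bucket.add_param(param)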