Example No. 1
    def to_shape(self, image, bound='zero'):
        """Crop/Pad a volume to match a target shape

        Parameters
        ----------
        image : (channel, *spatial) tensor
            Input image
        bound : str, default='zero'
            Method to fill out-of-bounds.

        Returns
        -------
        image : (channel, *shape) tensor
            Cropped/padded image

        """
        oshape = image.shape[1:]
        # a fixed target shape takes precedence over min/max/mult constraints
        if self.shape:
            oshape = (*image.shape[:-len(self.shape)], *self.shape)
            return utils.ensure_shape(image, oshape, mode=bound, side='both')
        # otherwise, enforce a minimum size per dimension ...
        if self.shape_min:
            shape_min = py.ensure_list(self.shape_min, len(oshape))
            oshape = [max(s, mn) for s, mn in zip(oshape, shape_min)]
        # ... then a maximum size per dimension ...
        if self.shape_max:
            shape_max = py.ensure_list(self.shape_max, len(oshape))
            oshape = [min(s, mx) for s, mx in zip(oshape, shape_max)]
        # ... and finally round down to the nearest multiple per dimension
        if self.shape_mult:
            shape_mult = py.ensure_list(self.shape_mult, len(oshape))
            oshape = [(s // m) * m for s, m in zip(oshape, shape_mult)]
        oshape = (*image.shape[:-len(oshape)], *oshape)
        return utils.ensure_shape(image, oshape, mode=bound, side='both')
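For reference, here is a minimal self-contained sketch of the centre crop/pad behaviour used above (zero padding, both sides). The function crop_or_pad is illustrative only and not part of the nitorch API:

import torch
import torch.nn.functional as F

def crop_or_pad(image, shape):
    """Centre crop/zero-pad a (channel, *spatial) tensor to the target spatial shape."""
    # crop: keep a centred window along every spatial dimension
    slices = [slice(None)]                      # keep the channel dimension intact
    for s, t in zip(image.shape[1:], shape):
        start = max(s - t, 0) // 2
        slices.append(slice(start, start + min(s, t)))
    image = image[tuple(slices)]
    # pad: F.pad expects (left, right) pairs starting from the last dimension
    pad = []
    for s, t in zip(reversed(image.shape[1:]), reversed(shape)):
        missing = max(t - s, 0)
        pad.extend([missing // 2, missing - missing // 2])
    return F.pad(image, pad)

For instance, crop_or_pad(torch.randn(1, 10, 12), (16, 8)) returns a (1, 16, 8) tensor: the first spatial dimension is padded and the second cropped, both centred.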
Example No. 2
    def forward(self, x):
        shape = x.shape[2:]
        dim = len(shape)
        # pad each spatial dimension up to the next multiple of the patch size
        pshape = [s + (-s % k) for s, k in zip(shape, self.kernel)]
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *pshape))
        # cut into non-overlapping patches, shuffle them, and stitch back together
        x = utils.unfold(x, self.kernel, collapse=True)
        x = x[:, :, torch.randperm(x.shape[2], device=x.device)]
        x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
        # crop back to the original spatial shape
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *shape))
        return x
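For the 2-D, non-overlapping case, the same unfold/permute/fold pattern can be written with plain reshapes. This is only an illustrative sketch, not the nitorch implementation (which also handles 3-D volumes and padding):

import torch

def shuffle_patches_2d(x, kh, kw):
    """Randomly permute non-overlapping (kh, kw) patches of a (B, C, H, W) tensor.

    Assumes H and W are multiples of kh and kw.
    """
    b, c, h, w = x.shape
    # split each image into a grid of patches
    x = x.view(b, c, h // kh, kh, w // kw, kw)
    x = x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, -1, kh, kw)
    # shuffle the patch axis
    x = x[:, :, torch.randperm(x.shape[2], device=x.device)]
    # stitch the grid back into an image
    x = x.reshape(b, c, h // kh, w // kw, kh, kw)
    return x.permute(0, 1, 2, 4, 3, 5).reshape(b, c, h, w)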
Example No. 3
    def forward(self, x):
        shape = x.shape[2:]
        dim = len(shape)
        # pad each spatial dimension up to the next multiple of the patch size
        pshape = [s + (-s % k) for s, k in zip(shape, self.kernel)]
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *pshape))
        x = utils.unfold(x, self.kernel, collapse=True)
        for n in range(self.nb_swap):
            # pick two distinct patches and swap them
            i1, i2 = torch.randperm(x.shape[2])[:2].tolist()
            tmp = x[:, :, i1].clone()
            x[:, :, i1] = x[:, :, i2]
            x[:, :, i2] = tmp
        x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *shape))
        return x
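A note on the swap above: for tensors, the usual Python tuple-swap idiom operates on views of the same storage and silently loses one of the two patches, which is why an explicit clone is needed. A quick demonstration on the rows of a small tensor:

import torch

t = torch.tensor([[1., 2.], [3., 4.]])
# tuple swap on views silently fails: both rows end up identical
t[0], t[1] = t[1], t[0]
print(t)        # tensor([[3., 4.], [3., 4.]])

t = torch.tensor([[1., 2.], [3., 4.]])
# cloning one side first performs a real swap
tmp = t[0].clone()
t[0] = t[1]
t[1] = tmp
print(t)        # tensor([[3., 4.], [1., 2.]])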
Example No. 4
    def forward(self, x):
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

        # sample a random patch size per spatial dimension and clamp it
        # to a sensible range
        kernel = [self.kernel(k_e, k_s).sample()
                  for k_e, k_s in zip(kernel_exp, kernel_scale)]
        shape = x.shape[2:]
        kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                  for i, k in enumerate(kernel)]
        # pad each spatial dimension up to the next multiple of the patch size
        pshape = [s + (-s % k) for s, k in zip(shape, kernel)]
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *pshape))
        # cut into non-overlapping patches, shuffle them, and stitch back together
        x = utils.unfold(x, kernel, collapse=True)
        x = x[:, :, torch.randperm(x.shape[2], device=x.device)]
        x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *shape))
        return x
Example No. 5
    def forward(self, x):
        dim = x.dim() - 2
        backend = utils.backend(x)
        kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
        kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)

        shape = x.shape[2:]
        for n in range(self.nb_drop):
            # sample a random patch size per spatial dimension and clamp it
            kernel = [self.kernel(k_e, k_s).sample()
                      for k_e, k_s in zip(kernel_exp, kernel_scale)]
            kernel = [torch.clamp(k, min=4, max=shape[i]).int().item()
                      for i, k in enumerate(kernel)]
            # pad each spatial dimension up to the next multiple of the patch size
            pshape = [s + (-s % k) for s, k in zip(shape, kernel)]
            x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *pshape))
            x = utils.unfold(x, kernel, collapse=True)
            # zero out one randomly chosen patch
            i1 = torch.randint(low=0, high=x.shape[2], size=(1,)).item()
            x[:, :, i1] = 0
            x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
            x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *shape))
        return x
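Example No. 5 zeroes a patch on the grid used for unfolding; the underlying idea is close to cutout-style block dropout. A minimal sketch with plain indexing (illustrative only: here the block is placed freely rather than on a patch grid, and the spatial size is assumed to be at least the block size):

import torch

def random_block_dropout_2d(x, block=16):
    """Zero one randomly placed (block x block) region of a (B, C, H, W) tensor."""
    b, c, h, w = x.shape
    i = torch.randint(0, h - block + 1, (1,)).item()
    j = torch.randint(0, w - block + 1, (1,)).item()
    x = x.clone()                       # keep the input untouched
    x[..., i:i + block, j:j + block] = 0
    return x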
Example No. 6
    def forward(self, x, model, **fwdargs):
        shape = x.shape[2:]
        dim = len(shape)
        if isinstance(self.patch_size, int):
            patch_size = [self.patch_size] * dim
        else:
            patch_size = self.patch_size
        if isinstance(self.stride, int):
            stride = [self.stride] * dim
        else:
            stride = self.stride
        # pad so that the strided windows tile the padded volume exactly
        pshape = [s + (k - s % st) for s, k, st in zip(shape, patch_size, stride)]
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *pshape))
        # extract (possibly overlapping) patches and run the model on each one
        x = utils.unfold(x, kernel_size=self.patch_size, stride=self.stride, collapse=True)
        x = torch.split(x, 1, dim=2)
        x = [x_.squeeze(2) for x_ in x]
        x = [model(x_, **fwdargs) for x_ in x]
        x = [x_.unsqueeze(dim=2) for x_ in x]
        x = torch.cat(x, dim=2)
        # fold the per-patch outputs back, reducing values in overlapping regions
        x = utils.fold(x, dim=dim, stride=self.stride, collapsed=True, shape=pshape,
                       reduction=self.reduction)
        x = utils.ensure_shape(x, (x.shape[0], x.shape[1], *shape))
        return x
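Example No. 6 amounts to sliding-window (patch-wise) inference with a reduction over overlapping outputs. A loop-based 2-D sketch of the same idea, assuming the model returns an output with the same shape as its input, a stride no larger than the patch size, and a mean reduction:

import torch
import torch.nn.functional as F

def sliding_window_2d(x, model, patch=64, stride=32):
    """Apply model patch-wise to a (B, C, H, W) tensor and average overlapping outputs."""
    b, c, h, w = x.shape
    # pad so that the last window ends exactly at the (padded) border
    ph = (stride - (h - patch) % stride) % stride if h > patch else patch - h
    pw = (stride - (w - patch) % stride) % stride if w > patch else patch - w
    xp = F.pad(x, (0, pw, 0, ph))
    hh, ww = xp.shape[-2:]
    out = torch.zeros_like(xp)
    count = torch.zeros(1, 1, hh, ww, device=x.device)
    for i in range(0, hh - patch + 1, stride):
        for j in range(0, ww - patch + 1, stride):
            out[..., i:i + patch, j:j + patch] += model(xp[..., i:i + patch, j:j + patch])
            count[..., i:i + patch, j:j + patch] += 1
    out = out / count                   # mean over overlapping windows
    return out[..., :h, :w]             # crop away the padding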
Example No. 7
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict
            Can be used to override the parameters stored in the object
            (shape, mean, amplitude, fwhm, channel, basis, dtype, device).

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        amplitude = overload.get('amplitude', self.amplitude)
        fwhm = overload.get('fwhm', self.fwhm)
        channel = overload.get('channel', self.channel)
        basis = overload.get('basis', self.basis)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)

        # sample if parameters are callable
        mean = mean() if callable(mean) else mean
        amplitude = amplitude() if callable(amplitude) else amplitude
        fwhm = fwhm() if callable(fwhm) else fwhm

        # device/dtype
        mean = torch.as_tensor(mean, dtype=dtype, device=device)
        amplitude = torch.as_tensor(amplitude, dtype=dtype, device=device)
        fwhm = torch.as_tensor(fwhm, dtype=dtype, device=device)

        # reshape
        nb_dim = len(shape)
        full_shape = [batch, channel, *shape]
        mean = mean.expand(full_shape)
        amplitude = amplitude.expand(full_shape)
        fwhm = fwhm.expand([batch, channel, nb_dim])

        conv = torch.nn.functional.conv1d if nb_dim == 1 else \
               torch.nn.functional.conv2d if nb_dim == 2 else \
               torch.nn.functional.conv3d if nb_dim == 3 else None

        # convert squared-exponential (SE) parameters into the white-noise
        # amplitude and Gaussian smoothing kernel that produce the same field
        sigma_se = fwhm / math.sqrt(8 * math.log(2))   # FWHM -> standard deviation
        sigma_se = unsqueeze(sigma_se.prod(dim=-1), dim=-1, ndim=nb_dim)
        # rescale the amplitude to compensate for the variance lost by smoothing
        amplitude = amplitude * (2 * math.pi) ** (nb_dim / 4) * sigma_se.sqrt()
        fwhm = fwhm * math.sqrt(2)

        # smooth
        samples_b = []
        for b in range(batch):
            samples_c = []
            for c in range(channel):
                kernel = smooth('gauss',
                                fwhm[b, c],
                                basis=basis,
                                device=device,
                                dtype=dtype)

                # compute input shape
                pad_shape = [
                    shape[d] + kernel[d].shape[d + 2] - 1
                    for d in range(nb_dim)
                ]
                mean1 = ensure_shape(mean[b, c],
                                     pad_shape,
                                     mode='reflect2',
                                     side='both')
                amplitude1 = ensure_shape(amplitude[b, c],
                                          pad_shape,
                                          mode='reflect2',
                                          side='both')

                # generate sample
                sample = torch.distributions.Normal(mean1, amplitude1).sample()
                sample = sample[None, None, ...]

                # convolve
                for ker in kernel:
                    sample = conv(sample, ker)

                samples_c.append(sample)

            samples_b.append(torch.cat(samples_c, dim=1))

        sample = torch.cat(samples_b, dim=0)

        return sample
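The recipe in Example No. 7 is to sample independent Gaussian noise and then convolve it with a separable Gaussian kernel so that nearby voxels become correlated. A minimal 2-D sketch of that recipe (illustrative only; it uses plain torch and a simple per-pass variance correction instead of nitorch's exact amplitude rescaling):

import math
import torch
import torch.nn.functional as F

def gauss_kernel1d(fwhm, dtype=torch.float32):
    # 1D Gaussian kernel from a full-width-at-half-maximum value
    sigma = fwhm / math.sqrt(8 * math.log(2))
    radius = int(math.ceil(3 * sigma))
    t = torch.arange(-radius, radius + 1, dtype=dtype)
    k = torch.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()

def smooth_random_field_2d(shape, fwhm=8.0, mean=0.0, amplitude=1.0):
    k = gauss_kernel1d(fwhm)
    pad = len(k) // 2
    # white Gaussian noise, padded so the 'valid' convolutions return `shape`
    noise = torch.randn(1, 1, shape[0] + 2 * pad, shape[1] + 2 * pad)
    # separable smoothing: one 1D convolution per spatial dimension
    noise = F.conv2d(noise, k.view(1, 1, -1, 1))
    noise = F.conv2d(noise, k.view(1, 1, 1, -1))
    # each smoothing pass shrinks the std by sqrt(sum(k^2)); undo both passes
    # so the field keeps (approximately) the requested amplitude
    noise = noise / k.pow(2).sum()
    return mean + amplitude * noise[0, 0]

For instance, smooth_random_field_2d((64, 64), fwhm=12.0) returns a 64x64 map that varies smoothly across the image.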