def to_shape(self, image, bound='zero'):
    """Crop/Pad a volume to match a target shape

    If ``self.shape`` is set, the trailing dimensions of ``image`` are
    cropped/padded to it. Otherwise the spatial shape of ``image`` is
    clamped to ``self.shape_min`` / ``self.shape_max`` and then rounded
    down to a multiple of ``self.shape_mult`` (each step only if the
    corresponding attribute is set).

    Parameters
    ----------
    image : (channel, *spatial) tensor
        Input image
    bound : str, default='zero'
        Method to fill out-of-bounds.

    Returns
    -------
    image : (channel, *shape) tensor
        Cropped/padded image
    """
    if self.shape:
        # Explicit target: replace the trailing len(self.shape) dimensions.
        oshape = (*image.shape[:-len(self.shape)], *self.shape)
        return utils.ensure_shape(image, oshape, mode=bound, side='both')

    oshape = image.shape[1:]
    if self.shape_min:
        shape_min = py.ensure_list(self.shape_min, len(oshape))
        oshape = [max(s, mn) for s, mn in zip(oshape, shape_min)]
    if self.shape_max:
        shape_max = py.ensure_list(self.shape_max, len(oshape))
        oshape = [min(s, mx) for s, mx in zip(oshape, shape_max)]
    if self.shape_mult:
        shape_mult = py.ensure_list(self.shape_mult, len(oshape))
        # Round down to a multiple of m, but keep at least one multiple:
        # plain `(s // m) * m` turns any dimension smaller than m into 0
        # and yields an empty volume.
        oshape = [max(s // m, 1) * m for s, m in zip(oshape, shape_mult)]
    # oshape always covers every dimension after the first one, so the
    # leading (channel) dimension is kept as is.
    oshape = (*image.shape[:1], *oshape)
    return utils.ensure_shape(image, oshape, mode=bound, side='both')
def forward(self, x):
    """Randomly shuffle non-overlapping ``self.kernel``-sized patches.

    Parameters
    ----------
    x : tensor
        Input of shape (B, C, *spatial); the first two dimensions are
        carried through untouched (presumably batch and channel).

    Returns
    -------
    x : tensor
        Tensor of the same shape as the input, whose patches have been
        randomly permuted.
    """
    shape = x.shape[2:]
    dim = len(shape)
    # Pad each spatial size up to the next multiple of the kernel size.
    # The former `s + (k - s % k)` added a full extra kernel when `s` was
    # already divisible, so padding-only patches were shuffled into the
    # visible image while real patches were cropped away.
    pshape = [-(-s // k) * k for s, k in zip(shape, self.kernel)]
    x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
    x = utils.unfold(x, self.kernel, collapse=True)
    # Patches live along dim 2 after the collapsed unfold.
    perm = torch.randperm(x.shape[2], device=x.device)
    x = x[:, :, perm]
    x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
    x = utils.ensure_shape(x, (*x.shape[:2], *shape))
    return x
def forward(self, x):
    """Randomly swap ``self.nb_swap`` pairs of non-overlapping patches.

    Parameters
    ----------
    x : tensor
        Input of shape (B, C, *spatial); the first two dimensions are
        carried through untouched (presumably batch and channel).

    Returns
    -------
    x : tensor
        Tensor of the same shape as the input in which random pairs of
        ``self.kernel``-sized patches have been exchanged.
    """
    shape = x.shape[2:]
    dim = len(shape)
    # Pad up to the next multiple of the kernel size (no extra row of
    # padding-only patches when the size is already divisible).
    pshape = [-(-s // k) * k for s, k in zip(shape, self.kernel)]
    x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
    x = utils.unfold(x, self.kernel, collapse=True)

    # Compose all swaps into one permutation of patch indices and gather
    # once. A tuple swap `x[:, :, i1], x[:, :, i2] = x[:, :, i2], x[:, :, i1]`
    # does not work on tensors: the right-hand sides are views, so the
    # second write copies back the already-overwritten slice and both slots
    # end up holding patch i2. Gathering is also out-of-place, so nothing
    # is written into a tensor that might alias the caller's input.
    nb_patches = x.shape[2]
    perm = torch.arange(nb_patches)
    for _ in range(self.nb_swap):
        # `high` is exclusive: use nb_patches so the last patch is eligible
        # (and a single patch does not make the range empty).
        i1, i2 = torch.randint(low=0, high=nb_patches, size=(2,)).tolist()
        perm[[i1, i2]] = perm[[i2, i1]]
    x = x[:, :, perm.to(x.device)]

    x = utils.fold(x, dim=dim, stride=self.kernel, collapsed=True, shape=pshape)
    x = utils.ensure_shape(x, (*x.shape[:2], *shape))
    return x
def forward(self, x):
    """Shuffle non-overlapping patches whose size is randomly sampled.

    One patch size per spatial dimension is drawn from
    ``self.kernel(kernel_exp, kernel_scale).sample()`` (presumably a
    distribution class), truncated to an int and clamped to
    ``[4, size_of_that_dimension]``.

    Parameters
    ----------
    x : tensor
        Input of shape (B, C, *spatial); the first two dimensions are
        carried through untouched (presumably batch and channel).

    Returns
    -------
    x : tensor
        Tensor of the same shape as the input with its patches permuted.
    """
    dim = x.dim() - 2
    backend = utils.backend(x)
    kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
    kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)
    shape = x.shape[2:]

    kernel = [self.kernel(k_e, k_s).sample()
              for k_e, k_s in zip(kernel_exp, kernel_scale)]
    kernel = [torch.clamp(k, min=4, max=s).int().item()
              for k, s in zip(kernel, shape)]

    # Pad up to the next multiple of the kernel size; the former
    # `s + (k - s % k)` added a whole padding-only patch row when `s` was
    # divisible, and those patches were then shuffled into the image.
    pshape = [-(-s // k) * k for s, k in zip(shape, kernel)]
    x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
    x = utils.unfold(x, kernel, collapse=True)
    x = x[:, :, torch.randperm(x.shape[2], device=x.device)]
    x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
    x = utils.ensure_shape(x, (*x.shape[:2], *shape))
    return x
def forward(self, x):
    """Zero out ``self.nb_drop`` randomly chosen patches.

    For each drop, a new patch size is drawn per spatial dimension from
    ``self.kernel(kernel_exp, kernel_scale).sample()`` (presumably a
    distribution class), truncated to an int and clamped to
    ``[4, size_of_that_dimension]``. The image is then split into patches
    of that size, and one patch is set to zero.

    Parameters
    ----------
    x : tensor
        Input of shape (B, C, *spatial); the first two dimensions are
        carried through untouched (presumably batch and channel).

    Returns
    -------
    x : tensor
        Tensor of the same shape as the input with dropped patches zeroed.
    """
    dim = x.dim() - 2
    backend = utils.backend(x)
    kernel_exp = utils.make_vector(self.kernel_exp, dim, **backend)
    kernel_scale = utils.make_vector(self.kernel_scale, dim, **backend)
    shape = x.shape[2:]

    for _ in range(self.nb_drop):
        kernel = [self.kernel(k_e, k_s).sample()
                  for k_e, k_s in zip(kernel_exp, kernel_scale)]
        kernel = [torch.clamp(k, min=4, max=s).int().item()
                  for k, s in zip(kernel, shape)]
        # Pad up to the next multiple of the kernel size (no extra patch
        # row when the size is already divisible).
        pshape = [-(-s // k) * k for s, k in zip(shape, kernel)]
        x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
        x = utils.unfold(x, kernel, collapse=True)
        # `high` is exclusive: use the patch count so the last patch is
        # eligible (the former `high=n-1` excluded it and failed for n=1).
        drop = torch.randint(low=0, high=x.shape[2], size=(1,), device=x.device)
        # Out-of-place fill: no padding is needed when sizes are already
        # divisible, so `x` may still alias the caller's input here.
        x = x.index_fill(2, drop, 0)
        x = utils.fold(x, dim=dim, stride=kernel, collapsed=True, shape=pshape)
        x = utils.ensure_shape(x, (*x.shape[:2], *shape))
    return x
def forward(self, x, model, **fwdargs):
    """Apply ``model`` patch-wise and stitch the outputs back together.

    Parameters
    ----------
    x : tensor
        Input of shape (B, C, *spatial); the first two dimensions are
        carried through untouched (presumably batch and channel).
    model : callable
        Called once per patch as ``model(patch, **fwdargs)``, where
        ``patch`` has shape (B, C, *patch_size).
    **fwdargs
        Extra keyword arguments forwarded to ``model``.

    Returns
    -------
    x : tensor
        Patch outputs folded back with ``self.reduction`` and
        cropped/padded to the input's spatial shape.
    """
    shape = x.shape[2:]
    dim = len(shape)
    patch_size = ([self.patch_size] * dim if isinstance(self.patch_size, int)
                  else self.patch_size)
    stride = ([self.stride] * dim if isinstance(self.stride, int)
              else self.stride)

    # Keep every patch whose start (0, s, 2s, ...) lies inside the input,
    # and make the last one fit. The former `n + (k - n % s)` is identical
    # except when n % s == 0. In that case it added a row of patches starting
    # at n, which contain only padding and fall outside the cropped output,
    # so they only cost extra model calls.
    pshape = [s * (max(n - 1, 0) // s) + k
              for n, k, s in zip(shape, patch_size, stride)]
    x = utils.ensure_shape(x, (*x.shape[:2], *pshape))
    x = utils.unfold(x, kernel_size=patch_size, stride=stride, collapse=True)

    # Patches live along dim 2: run the model on each one separately.
    outputs = [model(patch.squeeze(2), **fwdargs).unsqueeze(2)
               for patch in torch.split(x, 1, dim=2)]
    x = torch.cat(outputs, dim=2)

    x = utils.fold(x, dim=dim, stride=stride, collapsed=True, shape=pshape,
                   reduction=self.reduction)
    x = utils.ensure_shape(x, (*x.shape[:2], *shape))
    return x
def forward(self, batch=1, **overload):
    """Sample a smooth random field.

    Each (batch, channel) sample is white Gaussian noise with the given
    mean and amplitude. It is padded (``reflect2``) so that the valid
    convolutions restore ``shape``. It is then convolved successively
    with the kernels returned by ``smooth('gauss', ...)``.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        Per-call overrides of the attributes ``shape``, ``mean``,
        ``amplitude``, ``fwhm``, ``channel``, ``basis``, ``dtype`` and
        ``device``. ``mean``, ``amplitude`` and ``fwhm`` may be callables,
        in which case they are called once per invocation to draw a value.

    Returns
    -------
    field : (batch, channel, *shape) tensor
        Generated random field

    Raises
    ------
    ValueError
        If ``shape`` has more than 3 dimensions (no matching convolution).
    """
    # get arguments
    shape = overload.get('shape', self.shape)
    mean = overload.get('mean', self.mean)
    amplitude = overload.get('amplitude', self.amplitude)
    fwhm = overload.get('fwhm', self.fwhm)
    channel = overload.get('channel', self.channel)
    basis = overload.get('basis', self.basis)
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)

    # sample if parameters are callable
    mean = mean() if callable(mean) else mean
    amplitude = amplitude() if callable(amplitude) else amplitude
    fwhm = fwhm() if callable(fwhm) else fwhm

    # device/dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=device)
    amplitude = torch.as_tensor(amplitude, dtype=dtype, device=device)
    fwhm = torch.as_tensor(fwhm, dtype=dtype, device=device)

    nb_dim = len(shape)
    if nb_dim > 3:
        # Fail early and clearly instead of later calling `None(...)`.
        raise ValueError(f'Only 1D, 2D and 3D fields are supported, got {nb_dim}D')

    # reshape
    full_shape = [batch, channel, *shape]
    mean = mean.expand(full_shape)
    amplitude = amplitude.expand(full_shape)
    fwhm = fwhm.expand([batch, channel, nb_dim])

    conv = {
        1: torch.nn.functional.conv1d,
        2: torch.nn.functional.conv2d,
        3: torch.nn.functional.conv3d,
    }.get(nb_dim)

    # convert SE parameters to noise/kernel parameters
    sigma_se = fwhm / math.sqrt(8 * math.log(2))
    sigma_se = unsqueeze(sigma_se.prod(dim=-1), dim=-1, ndim=nb_dim)
    amplitude = amplitude * (2 * pi)**(nb_dim / 4) * sigma_se.sqrt()
    fwhm = fwhm * math.sqrt(2)

    # smooth
    samples_b = []
    for b in range(batch):
        samples_c = []
        for c in range(channel):
            # One kernel per spatial dimension, applied one after the other.
            kernel = smooth('gauss', fwhm[b, c], basis=basis,
                            device=device, dtype=dtype)

            # compute input shape: enlarge each dimension by (kernel size - 1)
            # along that dimension so the valid convolutions give back `shape`
            pad_shape = [shape[d] + kernel[d].shape[d + 2] - 1
                         for d in range(nb_dim)]
            mean1 = ensure_shape(mean[b, c], pad_shape,
                                 mode='reflect2', side='both')
            amplitude1 = ensure_shape(amplitude[b, c], pad_shape,
                                      mode='reflect2', side='both')

            # generate sample
            sample = torch.distributions.Normal(mean1, amplitude1).sample()
            sample = sample[None, None, ...]

            # convolve
            for ker in kernel:
                sample = conv(sample, ker)

            samples_c.append(sample)
        samples_b.append(torch.cat(samples_c, dim=1))

    sample = torch.cat(samples_b, dim=0)
    return sample