def shape(self, x):
    """Output shape of the equivalent forward call.

    Parameters
    ----------
    x : (b, c, **spatial) tensor or sequence[int]
        Either a tensor or its shape.

    Returns
    -------
    shape : tuple[int]
        (b, c, **spatial_out), where each spatial size is
        `(size - offset) // stride`.
    """
    in_shape = list(x.shape if torch.is_tensor(x) else x)
    nb_spatial = len(in_shape) - 2
    offsets = py.make_list(self.offset, nb_spatial)
    strides = py.make_list(self.stride, nb_spatial)
    out_spatial = []
    for size, off, step in zip(in_shape[2:], offsets, strides):
        out_spatial.append((size - off) // step)
    return tuple(in_shape[:2] + out_spatial)
def shape(self, x, output_padding=None, output_shape=None):
    """Output shape of the equivalent forward call.

    Parameters
    ----------
    x : (batch, channel, *in_spatial) tensor or sequence[int]
        Either a tensor or its shape.
    output_padding : [sequence of] int, default=self.output_padding
        Extra voxels added after `size * stride + offset`.
    output_shape : [sequence of] int, default=self.output_shape
        Explicit output spatial shape (takes precedence when truthy).

    Returns
    -------
    shape : tuple[int]
        (batch, channel, *out_spatial)

    Raises
    ------
    ValueError
        If both `output_padding` and `output_shape` are provided.
    """
    if torch.is_tensor(x):
        x = x.shape
    in_shape = list(x)

    got_padding = output_padding is not None
    got_shape = output_shape is not None
    if got_padding and got_shape:
        raise ValueError('Only one of `output_padding` or `output_shape` '
                         'should be provided.')
    if not (got_padding or got_shape):
        # neither argument given: fall back on the module's defaults
        output_padding = self.output_padding
        output_shape = self.output_shape

    nb_spatial = len(in_shape) - 2
    offsets = py.make_list(self.offset, nb_spatial)
    strides = py.make_list(self.stride, nb_spatial)
    paddings = py.make_list(output_padding, nb_spatial)
    if output_shape:
        out_spatial = py.make_list(output_shape, nb_spatial)
    else:
        out_spatial = [size * step + off + pad
                       for size, step, off, pad
                       in zip(in_shape[2:], strides, offsets, paddings)]
    return (*in_shape[:2], *out_spatial)
def _patch(patch, affine, shape, level):
    """Compute the patch size in voxels

    Parameters
    ----------
    patch : [sequence of] float, optionally followed by a unit string
        Patch size. If the last element is a string, it is the unit:
        'v...' (voxels), 'um', 'mm', 'cm', 'm' (world units, interpreted
        in RAS orientation) or 'p...' / '%' (percentage of `shape`).
        Default unit: 'pct'.
    affine : (dim+1, dim+1) tensor
        Affine matrix of the image (used to get RAS voxel sizes).
    shape : sequence[int]
        Spatial shape, used when the unit is a percentage.
    level : int
        Pyramid level. Voxel-unit sizes are divided by `2**level`.

    Returns
    -------
    patch : list[float]
        Patch size in voxels. Entries smaller than 1e-3 are set to 0.

    Raises
    ------
    ValueError
        If the unit is not recognized (including an empty unit string).
    """
    dim = affine.shape[-1] - 1
    patch = py.make_list(patch)
    unit = 'pct'
    if isinstance(patch[-1], str):
        *patch, unit = patch
    patch = py.make_list(patch, dim)
    unit = unit.lower()

    # `startswith` (rather than `unit[0]`) so that an empty unit string
    # reaches the ValueError below instead of raising IndexError.
    if unit.startswith('v'):
        # voxels
        patch = [float(p) / 2**level for p in patch]
    elif unit in ('m', 'mm', 'cm', 'um'):
        # world units: convert to mm, then divide by RAS voxel size
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else
                  1e3 if unit == 'm' else
                  1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        patch = [factor * p / v for p, v in zip(patch, vx_ras)]
        patch = _ras_to_layout(patch, affine)
    elif unit.startswith(('p', '%')):
        # percentage of shape
        patch = [0.01 * p * s for p, s in zip(patch, shape)]
    else:
        raise ValueError('Unknown patch unit: {}'.format(unit))

    # round down to zero small patch sizes
    patch = [0 if p < 1e-3 else p for p in patch]
    return patch
def _get_default_native(affines, shapes):
    """Get default native space

    Missing (None) affines are replaced by
    ``spatial.affine_default(shape)`` built from the matching shape.
    Affines and shapes are broadcast to a common length with
    ``py.make_list``.

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
    shapes : [sequence of] (3,) tensor_like

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor
    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3])
    shapes = shapes.unbind(dim=0)
    if torch.is_tensor(affines):
        affines = affines.reshape([-1, 4, 4])
        affines = affines.unbind(dim=0)
    elif affines is None:
        # A single missing affine (allowed by the docstring) must become
        # a one-element list, otherwise `len(affines)` below fails.
        affines = [None]
    n = max(len(shapes), len(affines))
    shapes = py.make_list(shapes, n)
    affines = py.make_list(affines, n)

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]
    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
def get_kernel(kernel, affine, shape, level):
    """Convert the provided kernel size (voxels, RAS mm or pct) to native voxels

    Parameters
    ----------
    kernel : [sequence of] float, optionally followed by a unit string
        Kernel size. If the last element is a string, it is the unit:
        'v...' (voxels), 'um', 'mm', 'cm', 'm' (world units, interpreted
        in RAS orientation) or 'p...' / '%' (percentage of `shape`).
        Default unit: 'pct'.
    affine : (dim+1, dim+1) tensor
        Affine matrix of the image (used to get RAS voxel sizes).
    shape : sequence[int]
        Spatial shape, used when the unit is a percentage.
    level : int
        Pyramid level. Voxel-unit sizes are divided by `2**level`.

    Returns
    -------
    kernel : list[int]
        Kernel size in voxels, rounded up and at least 2 per dimension.

    Raises
    ------
    ValueError
        If the unit is not recognized (including an empty unit string).
    """
    dim = affine.shape[-1] - 1
    kernel = py.make_list(kernel)
    unit = 'pct'
    if isinstance(kernel[-1], str):
        *kernel, unit = kernel
    kernel = py.make_list(kernel, dim)
    unit = unit.lower()

    # `startswith` (rather than `unit[0]`) so that an empty unit string
    # reaches the ValueError below instead of raising IndexError.
    if unit.startswith('v'):
        # voxels
        kernel = [p / 2**level for p in kernel]
    elif unit in ('m', 'mm', 'cm', 'um'):
        # world units: convert to mm, then divide by RAS voxel size
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else
                  1e3 if unit == 'm' else
                  1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        kernel = [factor * p / v for p, v in zip(kernel, vx_ras)]
        kernel = ras_to_layout(kernel, affine)
    elif unit.startswith(('p', '%')):
        # percentage of shape
        kernel = [0.01 * p * s for p, s in zip(kernel, shape)]
    else:
        raise ValueError('Unknown kernel unit: {}'.format(unit))

    # ensure kernel size is an integer >= 2 (else, no gradients)
    return [max(int(pymath.ceil(k)), 2) for k in kernel]
def __init__(self, dat, affine=None, dim=None, mask=None, bound='dct2',
             extrapolate=False, **backend):
    """Wrap a sequence of images.

    Parameters
    ----------
    dat : sequence
        Items that are not already `Image` instances are wrapped in one.
    affine : tensor or [sequence of] tensor, optional
        A single tensor is shared by all items. If None, each item's own
        `.affine` is used when it has one, else a default affine built
        from its last `dim` dimensions.
    dim : int, optional
        Number of spatial dimensions (default: taken from each item).
    mask : [sequence of] tensor, optional
    bound : str, default='dct2'
    extrapolate : bool, default=False
    backend : dict
        Accepted but not used here.
    """
    # super().__init__() is deliberately not called.

    def _affine_of(item):
        # The spatial dimension is resolved first (as in the default
        # path), even when the item carries its own affine.
        nb_spatial = dim or item.dim
        if callable(nb_spatial):
            nb_spatial = nb_spatial()
        if hasattr(item, 'affine'):
            return item.affine
        return spatial.affine_default(item.shape[-nb_spatial:],
                                      **utils.backend(item))

    if torch.is_tensor(affine):
        affine = [affine] * len(dat)
    elif affine is None:
        affine = [_affine_of(item) for item in dat]
    affine = py.make_list(affine, len(dat))
    mask = py.make_list(mask, len(dat))

    self._dat = []
    for item, item_affine, item_mask in zip(dat, affine, mask):
        if not isinstance(item, Image):
            item = Image(item, item_affine, mask=item_mask, dim=dim,
                         bound=bound, extrapolate=extrapolate)
        self._dat.append(item)
def _affine(self):
    """Affine orientation matrix of a series+level

    The result is computed once and cached in ``self._cache['_affine']``.
    It is built from, in order of precedence:

    1. the GeoTiff ``ModelTransformation`` tag (reshaped to 4x4);
    2. the GeoTiff ``ModelTiepointTag`` tags combined with the pixel
       scales (one 4x4 matrix per tie point, stacked into (N, 4, 4)
       when there is more than one tie point);
    3. otherwise ``affine_default(shape, zooms, layout=...)``, built from
       the spatial shape and the voxel sizes taken from the OME metadata,
       or else from the GeoTiff ``ModelPixelScaleTag``, or else 1.

    Returns
    -------
    affine : tensor
        (4, 4) or (N, 4, 4) for the GeoTiff cases, or the output of
        ``affine_default`` otherwise.
    """
    # TODO: I don't know yet how we should use GeoTiff to encode
    # affine matrices. In the matrix/zooms, their voxels are ordered
    # as [x, y, z] even though their dimensions in the returned array
    # are ordered as [Z, Y, X]. If we want to keep the same convention
    # as nitorch, I need to permute the matrix/zooms.
    if '_affine' not in self._cache:
        # read both metadata blocks inside the file context
        with self.tiffobj() as tiff:
            omexml = tiff.ome_metadata
            geotags = tiff.geotiff_metadata or {}
        zooms, units, axes = ome_zooms(omexml, self.series)
        if zooms:
            # convert to mm + drop non-spatial zooms
            # (presumably `f` is the unit's factor to metres, so
            # `f / 1e-3` converts to mm -- confirm against parse_unit)
            # NOTE(review): `axes` is not filtered together with `zooms`;
            # if ome_zooms can return non-spatial axes, the later
            # zip(axes, zooms) may pair the wrong axis and zoom -- confirm.
            units = [parse_unit(u) for u in units]
            zooms = [z * (f / 1e-3) for z, (f, type) in zip(zooms, units)
                     if type in ('m', 'pixel')]
            if 'ModelPixelScaleTag' in geotags:
                warn("Both OME and GeoTiff pixel scales are present: "
                     "{} vs {}. Using OME."
                     .format(zooms, geotags['ModelPixelScaleTag']))
        elif 'ModelPixelScaleTag' in geotags:
            zooms = geotags['ModelPixelScaleTag']
            axes = 'XYZ'
        else:
            # no voxel size information: unit zooms on spatial axes
            zooms = 1.
            axes = [ax for ax in self._axes if ax in 'XYZ']
        if 'ModelTransformation' in geotags:
            # full GeoTiff transformation: used as-is
            aff = geotags['ModelTransformation']
            aff = torch.as_tensor(aff, dtype=torch.double).reshape(4, 4)
            self._cache['_affine'] = aff
        elif ('ModelTiepointTag' in geotags):
            # copied from tifffile
            # each tie point (i, j, k) -> (x, y, z) gives one matrix;
            # note the negated y scale
            sx, sy, sz = py.make_list(zooms, n=3)
            tiepoints = torch.as_tensor(geotags['ModelTiepointTag'])
            affines = []
            for tiepoint in tiepoints:
                i, j, k, x, y, z = tiepoint
                affines.append(torch.as_tensor(
                    [[sx, 0.0, 0.0, x - i * sx],
                     [0.0, -sy, 0.0, y + j * sy],
                     [0.0, 0.0, sz, z - k * sz],
                     [0.0, 0.0, 0.0, 1.0]], dtype=torch.double))
            affines = torch.stack(affines, dim=0)
            if len(tiepoints) == 1:
                affines = affines[0]
            self._cache['_affine'] = affines
        else:
            # map each known axis to its zoom, then reorder the zooms
            # to follow the spatial axes of the returned array
            zooms = py.make_list(zooms, n=len(axes))
            ax2zoom = {ax: zoom for ax, zoom in zip(axes, zooms)}
            axes = [ax for ax in self._axes if ax in 'XYZ']
            shape = [shp for shp, msk in zip(self._shape, self._spatial)
                     if msk]
            zooms = [ax2zoom.get(ax, 1.) for ax in axes]
            # array axes follow the [Z, Y, X] order described in the TODO:
            # Z -> R, Y -> P, anything else (X) -> S
            layout = [('R' if ax == 'Z' else 'P' if ax == 'Y' else 'S')
                      for ax in axes]
            aff = affine_default(shape, zooms, layout=''.join(layout))
            self._cache['_affine'] = aff
    return self._cache['_affine']
def shape(self, x, output_shape=None):
    """Output shape of the equivalent forward call.

    Parameters
    ----------
    x : (batch, channel, *in_spatial) tensor or sequence[int]
        Either a tensor or its shape.
    output_shape : [sequence of] int, optional
        Explicit output spatial shape. If not provided (or empty), each
        spatial size is multiplied by the stride.

    Returns
    -------
    shape : tuple[int]
        (batch, channel, *out_spatial)
    """
    in_shape = list(x.shape) if torch.is_tensor(x) else list(x)
    nb_spatial = len(in_shape) - 2
    strides = py.make_list(self.stride, nb_spatial)
    if not output_shape:
        out_spatial = [size * step for size, step in zip(in_shape[2:], strides)]
    else:
        out_spatial = py.make_list(output_shape, nb_spatial)
    return (*in_shape[:2], *out_spatial)
def fov(self, value):
    """Set the field of view as a (min, max) pair of 3D corners.

    Parameters
    ----------
    value : (min, max) or None
        `min` and `max` are each a [sequence of] 3 values (or a tensor).
        Missing entries (`None`, or `value=None`) are replaced by the
        corresponding bound returned by `self._max_fov()`.
    """
    if value is None:
        value = (None, None)
    fov_min, fov_max = value
    if torch.is_tensor(fov_min):
        fov_min = fov_min.flatten().tolist()
    fov_min = py.make_list(fov_min, 3)
    if torch.is_tensor(fov_max):
        fov_max = fov_max.flatten().tolist()
    fov_max = py.make_list(fov_max, 3)
    if any(v is None for v in fov_min) or any(v is None for v in fov_max):
        min0, max0 = self._max_fov()
        # Only fill missing (None) bounds: 0 is a valid coordinate and
        # must be kept (`v or default` would silently replace it).
        fov_min = [d if v is None else v for v, d in zip(fov_min, min0)]
        fov_max = [d if v is None else v for v, d in zip(fov_max, max0)]
    self._fov = (tuple(fov_min), tuple(fov_max))
def fov_size(self, value):
    """Set the field of view from its size, centered on the current index.

    Parameters
    ----------
    value : [sequence of] float or tensor
        Size along each of the 3 dimensions. A falsy size (None or 0)
        leaves that dimension's bounds as None, presumably so that the
        `fov` setter falls back on its default bounds.
    """
    if torch.is_tensor(value):
        value = value.flatten().tolist()
    sizes = py.make_list(value, 3)
    lower, upper = [], []
    for center, size in zip(self._index, sizes):
        if size:
            lower.append(center - size / 2)
            upper.append(center + size / 2)
        else:
            lower.append(None)
            upper.append(None)
    self.fov = [lower, upper]
def forward(self, *image, **overload):
    """Randomly flip one or several images along spatial dimensions.

    The same random flips are applied to every input image.

    Parameters
    ----------
    *image : (batch, channel, *spatial) tensor
    **overload : dict
        `prob` and `dim` override `self.prob` and `self.dim`.

    Returns
    -------
    image : (batch, channel, *spatial) tensor, or tuple of tensors
        A single tensor when a single image is given.
    """
    image = list(image)
    device = image[0].device
    nb_dim = image[0].dim() - 2

    # Candidate dimensions. The user's list is not padded to `nb_dim`:
    # padding repeats its last element, and `flip` rejects duplicate dims.
    dim = overload.get('dim', self.dim)
    dim = py.make_list(dim) if dim else list(range(-nb_dim, 0))
    dim = dim[:nb_dim]

    # one flip probability per candidate dimension
    prob = utils.make_vector(overload.get('prob', self.prob),
                             dtype=torch.float, device=device)
    prob = py.make_list(prob.tolist(), len(dim))
    prob = torch.as_tensor(prob, dtype=torch.float, device=device)

    # sample flips
    flip = torch.rand((len(dim),), device=device) > (1 - prob)
    dim = [d for d, f in zip(dim, flip) if f]
    if dim:
        for i, img in enumerate(image):
            image[i] = img.flip(dim)
    return image[0] if len(image) == 1 else tuple(image)
def get_one_hot_map(one_hot_map, nb_classes):
    """Return a well-formed one-hot map.

    Parameters
    ----------
    one_hot_map : [sequence of] [sequence of] int or None, optional
        Labels belonging to each class. At most one entry may be None
        (the implicit class). If it has `nb_classes - 1` entries, a final
        None (implicit) entry is appended. If empty, the default map is
        labels 1..nb_classes-1 followed by an implicit class.
    nb_classes : int

    Returns
    -------
    one_hot_map : list of (list[int] or None)

    Raises
    ------
    ValueError
        If the map length does not match `nb_classes`, or if more than
        one class is implicit.
    """
    mapping = py.make_list(one_hot_map or [])
    if not mapping:
        mapping = list(range(1, nb_classes))
    if len(mapping) == nb_classes - 1:
        mapping = mapping + [None]
    if len(mapping) != nb_classes:
        raise ValueError('Number of classes in prior and map '
                         'do not match: {} and {}.'.format(
                             nb_classes, len(mapping)))
    mapping = [None if elem is None else py.make_list(elem)
               for elem in mapping]
    nb_implicit = sum(1 for elem in mapping if elem is None)
    if nb_implicit > 1:
        raise ValueError('Cannot have more than one implicit class')
    return mapping
def _composition_jac(jac, rhs, lhs=None, type='grid', identity=None, **kwargs):
    """Jacobian of the composition `(lhs)o(rhs)`

    Parameters
    ----------
    jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of input RHS transformation
    rhs : ([batch], *spatial, ndim) tensor
        RHS transformation
    lhs : ([batch], *spatial, ndim) tensor, default=`rhs`
        LHS small displacement
    type : str or (str, str), default='grid'
        Types of (rhs, lhs). The lhs type is forwarded to
        `grid_jacobian`. When the rhs type is not 'grid', the identity
        grid is added to `rhs` before it is used for sampling.
    identity : ([batch], *spatial, ndim) tensor, optional
        Precomputed identity grid (only used when rhs is not a 'grid').
    kwargs : dict
        Accepted, but not currently forwarded to any call.

    Returns
    -------
    composed_jac : ([batch], *spatial, ndim, ndim) tensor
        Jacobian of composition
    """
    if lhs is None:
        lhs = rhs
    type_rhs, type_lhs = py.make_list(type, 2)
    jac_lhs = grid_jacobian(lhs, type=type_lhs)
    if type_rhs != 'grid':
        # presumably a displacement: turn it into a grid
        if identity is None:
            ndim = rhs.shape[-1]
            identity = identity_grid(rhs.shape[-ndim - 1:-1],
                                     **utils.backend(rhs))
        rhs = rhs + identity
    jac_lhs = _pull_jac(jac_lhs, rhs)
    return torch.matmul(jac_lhs, jac)
def code_to_pattern(code, kernel_size, **backend):
    """Convert a unique code to a sampling pattern

    Bit `i` of `code` (least significant first) sets element `i` of the
    row-major flattened kernel.

    Parameters
    ----------
    code : int or ([*batch]) tensor[int]
    kernel_size : sequence of int
    dtype : torch.dtype, default=torch.bool
    device : torch.device, default=`code.device` for tensor codes

    Returns
    -------
    pattern : ([*batch], *kernel_size) tensor[bool]
    """
    backend.setdefault('dtype', torch.bool)
    batched = torch.is_tensor(code)
    if batched:
        backend.setdefault('device', code.device)
    kernel_size = py.make_list(kernel_size)

    def decode(single_code):
        # decode one code into a (*kernel_size) pattern
        out = torch.zeros(kernel_size, **backend)
        flat = out.flatten()  # writes through to `out`
        for bit in range(len(flat)):
            flat[bit] = bool((single_code >> bit) & 1)
        return out

    if not batched:
        return decode(code)

    patterns = code.new_zeros([*code.shape, *kernel_size], **backend)
    for value in code.unique():
        patterns[code == value] = decode(value)
    return patterns
def index(self, value):
    """Set the current index (a 3D point).

    Parameters
    ----------
    value : [sequence of] float or tensor
        Coordinates. `None` entries are replaced by the center of the
        current field of view along that dimension.
    """
    if torch.is_tensor(value):
        value = value.flatten().tolist()
    coords = py.make_list(value, 3)
    centered = []
    for coord, lo, hi in zip(coords, *self.fov):
        centered.append((hi + lo) / 2 if coord is None else coord)
    self._index = tuple(centered)
def forward(self, q, k, v, **overload):
    """Windowed dot-product attention.

    Keys and values are (optionally) padded and unfolded into windows.
    Weights are the softmax, over each window, of the query-key dot
    products, and the output is the weighted sum of the values.

    Parameters
    ----------
    q : (b, c, *spatial)
        Queries
    k : (b, c, *spatial)
        Keys
    v : (b, c, *spatial)
        Values
    **overload : dict
        `kernel_size`, `stride`, `padding`, `padding_mode`.

    Returns
    -------
    x : (b, c, *spatial)
    """
    kernel_size = overload.pop('kernel_size', self.kernel_size)
    # NOTE(review): the default stride is `self.kernel_size` (i.e.
    # non-overlapping windows), not a separate stride attribute -- confirm
    stride = overload.pop('stride', self.kernel_size)
    padding = overload.pop('padding', self.padding)
    padding_mode = overload.pop('padding_mode', self.padding_mode)
    nb_dim = q.dim() - 2

    # pad keys and values
    if padding == 'auto':
        k = spatial.pad_same(nb_dim, k, kernel_size, bound=padding_mode)
        v = spatial.pad_same(nb_dim, v, kernel_size, bound=padding_mode)
    elif padding:
        pad_spec = [0] * 2 + py.make_list(padding, nb_dim)
        k = utils.pad(k, pad_spec, side='both', mode=padding_mode)
        v = utils.pad(v, pad_spec, side='both', mode=padding_mode)

    kernel_size = py.make_list(kernel_size, nb_dim)

    def _windows(t):
        # unfold into windows, then flatten each window into one axis
        t = utils.unfold(t, kernel_size, stride)
        return t.reshape([*t.shape[:nb_dim + 2], -1])

    # attention weights: softmax of query . key over each window
    keys = utils.movedim(_windows(k), 1, -1)
    queries = utils.movedim(q[..., None], 1, -1)
    weights = math.softmax(linalg.dot(keys, queries), dim=-1)
    weights = weights[:, None]  # add back channel dimension

    # new values: weighted sum of the values in each window
    values = _windows(v)
    return linalg.dot(weights, values)
def spherical_harmonics(shape, order=2, isocenter=None, **backend):
    """Generate a basis of spherical harmonics on a lattice

    Notes
    -----
    .. This should be checked!
    .. Only orders 1 and 2 implemented
    .. I tried to implement some sort of "circular" harmonics in
       dimension 2 but I don't know what I am doing.
    .. The basis is not orthogonal

    Parameters
    ----------
    shape : sequence of int
    order : {1, 2}, default=2
    isocenter : [sequence of] float, default=shape/2
        A single value is used for every dimension.
    dtype : torch.dtype, optional
    device : torch.device, optional

    Returns
    -------
    b : (*shape, K) tensor
        Basis. In 3D, K = 3 (order 1) or 5 (order 2), i.e. 2*order + 1.
        In 2D, K = 2 for both orders.

    Raises
    ------
    ValueError
        If the dimension is not 2 or 3, or the order is not 1 or 2.
    """
    shape = py.make_list(shape)
    dim = len(shape)
    if dim not in (2, 3):
        raise ValueError('Dimension must be 2 or 3')
    if order not in (1, 2):
        raise ValueError('Order must be 1 or 2')

    if isocenter is None:
        isocenter = [s / 2 for s in shape]
    isocenter = utils.make_vector(isocenter, **backend)
    if isocenter.numel() == 1:
        # broadcast a scalar isocenter (else `isocenter[i]` fails for i > 0)
        isocenter = isocenter.expand(dim)

    # centered coordinates, normalized to [-1, 1] (modified in place)
    ramps = identity_grid(shape, **backend)
    for i, ramp in enumerate(ramps.unbind(-1)):
        ramp -= isocenter[i]
        ramp /= shape[i] / 2

    if order == 1:
        return ramps

    # order == 2
    if dim == 3:
        basis = [
            ramps[..., 0] * ramps[..., 1],
            ramps[..., 0] * ramps[..., 2],
            ramps[..., 1] * ramps[..., 2],
            ramps[..., 0].square() - ramps[..., 1].square(),
            ramps[..., 0].square() - ramps[..., 2].square()
        ]
    else:  # dim == 2
        basis = [
            ramps[..., 0] * ramps[..., 1],
            ramps[..., 0].square() - ramps[..., 1].square()
        ]
    return torch.stack(basis, -1)
def loadseg(self, fnames, segtype='label', lookup=None, dtype=None,
            device=None):
    """Load a volume from disk

    One file is loaded per channel and the files are concatenated along
    the first dimension. Label inputs are loaded as `torch.int`. If a
    `lookup` is given, labels are merged with `utils.merge_labels`. They
    can then be converted to a one-hot (explicit or implicit) encoding.

    Parameters
    ----------
    fnames : str or sequence[str]
        One file per channel.
    segtype : [tuple] {'label', 'implicit', 'explicit'}, default='label'
        (input type, output type). A single value is used for both.
    lookup : list of [list of] int, optional
        Passed to `utils.merge_labels` (label inputs only).
    dtype : torch.dtype, optional
        Data type of soft inputs and of the one-hot output.
    device : torch.device, optional

    Returns
    -------
    dat : (channels | 1, *spatial) tensor
    """
    insegtype, outsegtype = py.make_list(segtype, 2)
    # label maps are always read as integers
    sdtype = dtype if insegtype != 'label' else torch.int
    fnames = py.make_list(fnames)
    channels = []
    for fname in fnames:  # loop across channels
        dat = self.load(fname, dtype=sdtype, device=device)
        channels.append(dat)
    if len(channels) > 1:
        channels = torch.cat(channels, 0)
    else:
        channels = channels[0]
    if insegtype == 'label' and lookup:
        channels = utils.merge_labels(channels, lookup)
    if insegtype == 'label' and outsegtype != 'label':
        channels = utils.one_hot(channels, dim=0,
                                 implicit=outsegtype == 'implicit',
                                 implicit_index=0, dtype=dtype)
        # NOTE(review): this rescales the one-hot map by the number of
        # classes (counting the implicit one), so it no longer sums to 1
        # per voxel -- confirm this is intended. The in-place division may
        # also fail if `dtype` is left to None and one_hot returns an
        # integer tensor.
        if outsegtype == 'implicit':
            channels /= len(channels) + 1
        else:
            channels /= len(channels)
    return channels
def forward(self, x):
    """Keep every `stride`-th voxel, starting at `offset`, in each spatial dim.

    Parameters
    ----------
    x : (b, c, **spatial) tensor

    Returns
    -------
    x : (b, c, **spatial_out) tensor
        In each dimension, `spatial_out = (size - offset) // stride`,
        which matches the module's `shape` method.
    """
    dim = x.dim() - 2
    offset = py.make_list(self.offset, dim)
    stride = py.make_list(self.stride, dim)
    # With n = (sz - o) // st samples, the slice must stop at o + n*st.
    # Stopping at n*st alone drops samples whenever o >= st.
    slicer = [slice(o, o + ((sz - o) // st) * st, st)
              for o, st, sz in zip(offset, stride, x.shape[2:])]
    slicer = [slice(None)] * 2 + slicer
    return x[tuple(slicer)]
def __init__(self, shape, in_channels, out_channels, nb_levels=0,
             decoder=(32, 32, 32, 32), kernel_size=3,
             activation=tnn.LeakyReLU(0.2), unpool=None):
    """
    Parameters
    ----------
    shape : sequence[int]
        Output spatial shape
    in_channels : int
        Number of input channels (= meta variables)
    out_channels : int
        Number of output channels
    nb_levels : int, default=0
        Number of levels in the decoder.
        If 0: directly generate the image using a dense layer.
    decoder : sequence[int], default=(32, 32, 32, 32)
        Number of features after each layer. The first entry is the
        number of features produced by the dense layer.
        If len(decoder) is larger than the number of levels,
        additional stride-1 convolutions are applied.
    kernel_size : [sequence of] int, default=3
    activation : str or callable, default=LeakyReLU(0.2)
    unpool : {'conv', 'up', None}, default=None
        'conv' -> 2x2x2 strided convolution (no bias, no activation)
        'up'   -> linear upsampling
        None   -> use strided convolutions in the decoder
    """
    super().__init__()
    shape = py.make_list(shape)
    dim = len(shape)
    # spatial shape at the coarsest level (output of the dense layer)
    small_shape = [s // 2**nb_levels for s in shape]

    in_feat, *decoder = decoder
    self.dense = Linear(in_channels, py.prod(small_shape)*in_feat)
    # Bind the dense-layer feature count now: `in_feat` is reassigned
    # below, and a closure over it would read its *last* value.
    dense_feat = in_feat
    self.reshape = lambda x: x.reshape([-1, dense_feat, *small_shape])

    decoder, stack = decoder[:nb_levels], decoder[nb_levels:]
    if decoder:
        self.decoder = Decoder(dim, in_feat, decoder,
                               kernel_size=kernel_size,
                               activation=activation, unpool=unpool)
        in_feat = decoder[-1]
    else:
        self.decoder = lambda x: x
    if stack:
        self.stack = StackedConv(dim, in_feat, stack,
                                 kernel_size=kernel_size,
                                 activation=activation)
        in_feat = stack[-1]
    else:
        self.stack = lambda x: x
    self.final = StackedConv(dim, in_feat, out_channels,
                             kernel_size=kernel_size, activation=None)
def forward(self, x, output_padding=None, output_shape=None):
    """Place input voxels on a larger, stride-spaced output lattice.

    The output (shape given by `self.shape`) is zero-initialized. If
    `self.fill` is set, each input voxel is broadcast into a
    stride-sized patch of the output. Otherwise, input voxels are written
    every `stride` voxels starting at `offset`, and all other output
    voxels stay zero.

    Parameters
    ----------
    x : (batch, channel, *in_spatial) tensor
    output_padding : [sequence of] int, default=self.output_padding
    output_shape : [sequence of] int, default=self.output_shape

    Returns
    -------
    x : (batch, channel, *out_spatial) tensor
    """
    dim = x.dim() - 2
    offset = py.make_list(self.offset, dim)
    stride = py.make_list(self.stride, dim)
    new_shape = self.shape(x, output_padding=output_padding,
                           output_shape=output_shape)
    y = x.new_zeros(new_shape)
    if self.fill:
        # presumably `unfold` returns a view of `y` split into
        # stride-sized patches, so writing into `z` fills `y`
        z = utils.unfold(y, stride)
        # trailing singleton dims so that each voxel broadcasts over
        # its patch in `copy_`
        x = utils.unsqueeze(x, -1, dim)
        # NOTE(review): this slices the patch grid of `z` with
        # voxel-unit bounds (o, o + sz*st) -- confirm the intended units
        slicer = [slice(o, o+sz*st) for sz, st, o in
                  zip(x.shape[2:], stride, offset)]
        slicer = [slice(None)]*2 + slicer
        subz = z[tuple(slicer)]
        # crop the input to the extent of the destination
        slicer = [slice(mx) for mx in subz.shape[2:]]
        slicer = [slice(None)]*2 + slicer
        subz.copy_(x[tuple(slicer)])
    else:
        # strided view of the output starting at `offset`
        slicer = [slice(o, None, s) for o, s in zip(offset, stride)]
        slicer = [slice(None)]*2 + slicer
        suby = y[tuple(slicer)]
        # crop the input to the extent of the destination
        slicer = [slice(mx) for mx in suby.shape[2:]]
        slicer = [slice(None)]*2 + slicer
        suby.copy_(x[tuple(slicer)])
    return y
def forward(self, x, output_shape=None):
    """Upsample `x`.

    If this container holds a single layer, it is a strided conv, called
    with `output_shape`. Otherwise it holds (resize, conv): `x` is
    resized (factor 1/stride, or to `output_shape`) and then convolved.
    """
    if len(self) != 1:
        # resize + conv
        resize, conv = self
        factor = [1/s for s in py.make_list(self.stride)]
        return conv(resize(x, output_shape=output_shape, factor=factor))
    # strided conv
    conv, = self
    return conv(x, output_shape=output_shape)
def __init__(self, shape=None, nb_classes=None,
             amplitude='lognormal', amplitude_exp=5, amplitude_scale=2,
             fwhm='lognormal', fwhm_exp=15, fwhm_scale=5,
             logits=False, implicit=False, device=None, dtype=None):
    """
    Parameters
    ----------
    shape : sequence[int], optional
        Output shape
    nb_classes : int
        Number of classes (excluding background). Although it defaults
        to None, a value is required: the field is built with
        `nb_classes + 1` channels.
    amplitude : {'normal', 'lognormal', 'uniform', 'gamma'}, default='lognormal'
    amplitude_exp : float or (channel,) vector_like, default=5
    amplitude_scale : float or (channel,) vector_like, default=2
    fwhm : {'normal', 'lognormal', 'uniform', 'gamma'}, default='lognormal'
    fwhm_exp : float or (channel,) vector_like, default=15
    fwhm_scale : float or (channel,) vector_like, default=5
    logits : bool, default=False
        Outputs are log-odds instead of probabilities
    implicit : bool, default=False
        Outputs have an implicit K+1-th class.
    device : torch.device, optional
        Output tensor device.
    dtype : torch.dtype, optional
        Output tensor datatype.
    """
    super().__init__()
    self.logits = logits
    self.implicit = implicit
    field_options = dict(amplitude=amplitude,
                         amplitude_exp=amplitude_exp,
                         amplitude_scale=amplitude_scale,
                         fwhm=fwhm,
                         fwhm_exp=fwhm_exp,
                         fwhm_scale=fwhm_scale,
                         device=device,
                         dtype=dtype)
    self.field = HyperRandomFieldSpline(shape=py.make_list(shape),
                                        channel=nb_classes + 1,
                                        **field_options)
def multitrainer(trainers, keep_on_gpu=True, parallel=False, cuda_pool=(0, )):
    """Train multiple models in parallel.

    Parameters
    ----------
    trainers : sequence[ModelTrainer]
    keep_on_gpu : bool, default=True
        Keep all models on GPU (risk of out-of-memory)
    parallel : bool or int, default=False
        Train model in parallel.
        If an int, only this number of models will be trained in parallel.
    cuda_pool : sequence[int], default=[0]
        IDs of GPUs that can be used to dispatch the models.

    Returns
    -------
    trainers : list[ModelTrainer]
        The trainers returned by the last training step. In parallel mode
        the work happens in worker processes, so these returned objects,
        not the input ones, hold the trained state.
    """
    # materialize first: the trainers are traversed several times
    trainers = tuple(trainers)
    n = len(trainers)
    initial_epoch = max(min(trainer.initial_epoch for trainer in trainers), 1)
    nb_epoch = max(trainer.nb_epoch for trainer in trainers)

    pool = None
    chunksize = 1
    if parallel:
        parallel = n if parallel is True else parallel
        chunksize = max(n // (2 * parallel), 1)
        pool = Pool(parallel)
        # pad/crop the list of GPU ids to `parallel` entries
        cuda_pool = py.make_list(cuda_pool, parallel)
    if not torch.cuda.is_available():
        cuda_pool = []
    if not keep_on_gpu:
        for trainer in trainers:
            trainer.to(device='cpu')

    def _map(func, args):
        # dispatch to the worker pool when available, else run serially
        if pool is None:
            return list(map(func, args))
        return list(pool.map(func, args, chunksize=chunksize))

    try:
        # --- init ---
        args = zip(trainers, [keep_on_gpu] * n, [cuda_pool] * n)
        trainers = _map(_init1, args)

        # --- train ---
        for epoch in range(initial_epoch + 1, nb_epoch + 1):
            args = zip(trainers, [epoch] * n, [keep_on_gpu] * n,
                       [cuda_pool] * n)
            trainers = _map(_train1, args)
    finally:
        # release worker processes
        if pool is not None:
            pool.close()
            pool.join()
    return trainers
def _init_from_fname(new, fnames, permission='r', keep_open=False,
                     **attributes):
    """Map one or several files and initialize `new` from them.

    Each mapped volume gets trailing singleton dimensions until it has 4
    dimensions. The volumes are concatenated along the last dimension,
    which is then moved to the front before calling
    `new._init_from_mapped(..., **attributes)`.
    """
    volumes = []
    for fname in py.make_list(fnames):
        volume = io.map(fname, permission=permission, keep_open=keep_open)
        while volume.dim < 4:
            volume = volume.unsqueeze(-1)
        volumes.append(volume)
    stacked = io.cat(volumes, -1).permute([-1, 0, 1, 2])
    new._init_from_mapped(stacked, **attributes)
def kernel_apply(kspace, patterns, kernel_size, kernels, inplace=False):
    """Apply a GRAPPA kernel to an accelerated k-space

    All batch elements should have the same sampling pattern

    Parameters
    ----------
    kspace : ([*batch], coils, *freq) tensor
        Accelerated k-space
    patterns : (*freq) tensor[long]
        Code of sampling pattern about each k-space location
    kernel_size : [sequence of] int
        GRAPPA kernel size
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        Dictionary of GRAPPA kernels (keys are pattern codes)
    inplace : bool, default=False
        Write the result into `kspace` instead of a copy.

    Returns
    -------
    kspace : ([*batch], coils, *freq) tensor
    """
    ndim = patterns.dim()
    coils, *freq = kspace.shape[-ndim - 1:]
    batch = kspace.shape[:-ndim - 1]
    kernel_size = py.make_list(kernel_size, ndim)

    kspace_out = kspace if inplace else kspace.clone()

    # neighborhood of each frequency; presumably laid out as
    # ([*batch], coils, *freq, *kernel_size)
    kspace = utils.pad(kspace, [(k - 1) // 2 for k in kernel_size],
                       side='both')
    kspace = utils.unfold(kspace, kernel_size, stride=1)

    # one full slice per kernel axis, whatever the number of dimensions
    kernel_axes = (slice(None),) * ndim

    def t(x):
        return x.transpose(-1, -2)

    for code, kernel in kernels.items():
        pattern = code_to_pattern(code, kernel_size, device=kspace.device)
        pattern_size = int(pattern.sum())
        mask = patterns == code
        # ([*batch], coils, nb_masked, pattern_size)
        kspace1 = kspace[(Ellipsis, mask, *kernel_axes)][..., pattern]
        kspace1 = kspace1.transpose(-2, -3) \
                         .reshape([*batch, -1, coils * pattern_size])
        kernel = kernel.reshape([*batch, coils, coils * pattern_size])
        kspace1 = t(kspace1.matmul(t(kernel)))
        kspace_out[..., mask] = kspace1
    return kspace_out
def _prepare_pyramid_levels(losses, opt, dim=None):
    """
    For each loss, compute the pyramid levels of `fix` and `mov`
    that must be computed.

    Returns
    -------
    levels : list[dict]
        One distinct {'fix': ..., 'mov': ...} dict per loss.
    """
    if opt.name == 'none':
        # one distinct dict per loss (`[d] * n` would alias a single dict)
        return [{'fix': None, 'mov': None} for _ in range(len(losses))]

    def _image_info(image):
        # first entry of the image's voxel size, and its (padded) shape
        dat, _ = _map_image(image.files, dim)
        vx = dat.voxel_size[0].tolist()
        shape = dat.shape[:dim]
        if image.pad:
            pad = py.make_list(image.pad, len(shape))
            shape = [s + 2 * p for s, p in zip(shape, pad)]
        return vx, shape

    vxs = []
    shapes = []
    for loss in losses:
        # interleaved: fixed image, then moving image
        for image in (loss.fix, loss.mov):
            vx, shape = _image_info(image)
            vxs.append(vx)
            shapes.append(shape)

    levels = _pyramid_levels(vxs, shapes, opt)
    return [{'fix': fix, 'mov': mov}
            for fix, mov in zip(levels[::2], levels[1::2])]
def _softmax_fwd(input, dim=-1, implicit=False):
    """ SoftMax (safe).

    The maximum is subtracted before exponentiating for numerical
    stability. The input is not modified; a new tensor is returned.

    Parameters
    ----------
    input : torch.tensor
        Tensor with values.
    dim : int, default=-1
        Dimension to take softmax, defaults to last dimensions.
    implicit : bool or (bool, bool), default=False
        The first value relates to the input tensor and the second
        relates to the output tensor.
        - implicit[0] == True assumes that an additional (hidden) channel
          with value zero exists.
        - implicit[1] == True drops the last class from the
          softmaxed tensor.

    Returns
    -------
    Z : torch.tensor
        Soft-maxed tensor with values.
    """
    implicit_in, implicit_out = py.make_list(implicit, 2)

    maxval, _ = torch.max(input, dim=dim, keepdim=True)
    if implicit_in:
        maxval.clamp_min_(0)  # don't forget the class full of zeros
    input = input.clone().sub_(maxval).exp_()
    # `maxval` is reused as the output buffer unless it is still needed
    # below for the hidden class's term exp(-max)
    sumval = torch.sum(input, dim=dim, keepdim=True,
                       out=maxval if not implicit_in else None)
    if implicit_in:
        sumval += maxval.neg().exp()  # don't forget the class full of zeros
    input *= sumval.reciprocal_()

    if implicit_in and not implicit_out:
        # make the hidden class explicit: 1 - sum of the other classes
        background = input.sum(dim, keepdim=True).neg_().add_(1)
        input = torch.cat((input, background), dim=dim)
    elif implicit_out and not implicit_in:
        # drop the last class
        input = utils.slice_tensor(input, slice(-1), dim)
    return input
def __init__(self, dim=1, implicit=False):
    """
    Parameters
    ----------
    dim : int, default=1
        Dimension along which to take the softmax
    implicit : bool or (bool, bool), default=False
        (input, output) flags; a single value is used for both.
        - implicit[0] == True: the input has an additional hidden
          channel with value zero.
        - implicit[1] == True: the last class is dropped from the
          softmaxed output.
    """
    super().__init__()
    self.dim = dim
    self.implicit = py.make_list(implicit, 2)
def forward(self, *image, **overload):
    """Randomly crop a patch (same location for all images).

    Parameters
    ----------
    *image : (batch, channel, *spatial)
        The first image defines the spatial size. Other images must have
        the same spatial shape (their channel counts may differ).
    **overload : dict
        `shape` overrides `self.shape` (patch shape).

    Returns
    -------
    *image : (batch, channel, *patch_shape)
        A single tensor when a single image is given. Each patch
        dimension is min(patch size, image size). The crop location is
        drawn independently for each batch element.
    """
    image, *other_images = image
    image = torch.as_tensor(image)
    device = image.device
    dim = image.dim() - 2
    shape = py.make_list(overload.get('shape', self.shape), dim)
    shape = [min(s0, s1) for s0, s1 in zip(image.shape[2:], shape)]

    # sample shift: any start in [0, max_shift] (upper bound inclusive,
    # hence `+ 1` since randint's `high` is exclusive)
    max_shift = [d0 - d1 for d0, d1 in zip(image.shape[2:], shape)]
    shift = [[torch.randint(0, s + 1, [], device=device) if s > 0 else 0
              for s in max_shift] for _ in range(len(image))]

    output = image.new_empty([*image.shape[:2], *shape])
    other_outputs = [im.new_empty([*im.shape[:2], *shape])
                     for im in other_images]
    for b in range(len(image)):
        # subslice
        index = (b, slice(None))  # batch, channel
        index = index + tuple(slice(s, s+d) for s, d in zip(shift[b], shape))
        output[b] = image[index]
        for i in range(len(other_images)):
            other_outputs[i][b] = other_images[i][index]

    if len(other_images) > 0:
        return (output, *other_outputs)
    else:
        return output