def preprocess_labels(self, s):
    """Split a label map into labels to model and labels to predict.

    Parameters
    ----------
    s : tensor[int]
        Input label map.

    Returns
    -------
    s : tensor[int]
        Labels to model: one randomly chosen group of
        `self.droppable_labels` (or none) is zeroed, then the remaining
        labels are merged with `utils.merge_labels`.
    s0 : tensor[int]
        Labels to predict: restricted to `self.predicted_labels` (if set)
        and merged with `utils.merge_labels`.
    nb_labels_sampled : int
        Number of distinct labels remaining in `s` after dropping.
    nb_labels_predicted : int
        Number of predicted (foreground) labels plus one for label 0.
    """
    # s0 <- labels to predict
    # s <- labels to model
    # NOTE: `s0 = s` aliases the same tensor (no copy): the in-place
    # zeroing of dropped labels below also affects `s0` (and the caller's
    # tensor), so dropped labels are never predicted either.
    s0 = s
    all_labels = set(s.unique().tolist())
    if self.droppable_labels:
        # remove labels that are not modelled
        # E.g., this could be all non-brain labels (to model skull
        # stripping) or all left labels (to model hemi).
        ngroups = len(self.droppable_labels)
        # group == 0 means "drop nothing"; group > 0 selects one group
        group = torch.randint(ngroups + 1, [])
        if group > 0:
            dropped_labels = self.droppable_labels[group - 1]
            s[utils.isin(s, dropped_labels)] = 0
            all_labels -= set(dropped_labels)
    s = utils.merge_labels(s, sorted(all_labels))
    nb_labels_sampled = len(all_labels)
    if self.predicted_labels:
        predicted_labels = self.predicted_labels
        if isinstance(self.predicted_labels, int):
            # An int N means foreground labels 1..N. Background (0) is
            # added separately below, as for explicit label lists.
            predicted_labels = list(range(1, predicted_labels + 1))
        # remove labels that are not predicted
        s0[utils.isin(s0, predicted_labels).bitwise_not_()] = 0
        s0 = utils.merge_labels(s0, sorted([0, *predicted_labels]))
    else:
        # every label present in s0 except the smallest (background)
        predicted_labels = list(sorted(s0.unique().tolist()))[1:]
    # +1 accounts for the background label
    nb_labels_predicted = len(predicted_labels) + 1
    return s, s0, nb_labels_sampled, nb_labels_predicted
def get_hard_labels(x, one_hot_map, implicit=False):
    """Get MAP (hard) labels.

    Parameters
    ----------
    x : (B, C[-1] | 1, *spatial) tensor
        Soft / one-hot probabilities if floating point, hard labels
        otherwise.
    one_hot_map : list[list[int] or None]
        Used for hard labels only. Entry `k` lists the input label values
        mapped to output index `k`. A `None` entry is the implicit class:
        every value not listed in any other entry.
    implicit : bool, default=False
        Used for soft labels only; forwarded to `_pad_norm`.

    Returns
    -------
    x : (B, 1, *spatial) tensor[int]
    """
    if x.dtype in (torch.half, torch.float, torch.double):
        x = _pad_norm(x, implicit)
        # keepdim so that both branches return (B, 1, *spatial)
        return x.argmax(dim=1, keepdim=True)

    new_x = torch.zeros_like(x)
    for soft, hard in enumerate(one_hot_map):
        if hard is None:
            # implicit class: everything NOT listed in another entry
            listed = flatten([lab for lab in one_hot_map if lab is not None])
            mask = ~isin(x, listed)
        else:
            mask = isin(x, hard)
        new_x[mask] = soft
    return new_x
def missing(dat, missing):
    """Return a mask of missing data.

    Parameters
    ----------
    dat : tensor or array_like
        Input data. Torch tensors are tested with `utils.isin`, anything
        else with `np.isin`.
    missing : scalar or sequence
        Value(s) considered missing; normalised with `py.ensure_list`.

    Returns
    -------
    mask : tensor or ndarray
        True where `dat` matches one of the `missing` values.
    """
    values = py.ensure_list(missing)
    membership = utils.isin if torch.is_tensor(dat) else np.isin
    return membership(dat, values)
def _init_template_cat(self, images, one_hot_map):
    """Initialise the categorical template from a set of label images.

    The template is set to the log of the smoothed mean class
    probabilities across all images. If `self.implicit` is False, the
    last channel stores the remaining probability mass. Otherwise, all
    channels are log-ratios with respect to that remaining mass.

    Parameters
    ----------
    images : iterable of (B, C, *spatial) tensor
        Floating-point images are treated as soft labels (only the first
        `self.cat` channels are used). Other dtypes are treated as hard
        labels.
    one_hot_map : list or None
        Hard labels only: entry `k` is the label value(s) accumulated into
        template channel `k`. Defaults to labels `1..self.cat`.

    Raises
    ------
    ValueError
        If the images do not share the same spatial shape, or if no image
        is provided.
    """
    self.template.data.zero_()
    k = self.cat
    if not one_hot_map:
        one_hot_map = list(range(1, k + 1))
    n = 0
    for i, image in enumerate(images):
        image = image.to(self.template.device)
        # check shape
        shape = image.shape[2:]
        if self.template.shape[1:] != shape:
            if i == 0:
                # Resize the template to the first image. Keep its dtype
                # and device so the in-place accumulations below match.
                shape = [k + (not self.implicit), *shape]
                self.template = tnn.Parameter(torch.zeros(
                    shape, dtype=self.template.dtype,
                    device=self.template.device))
            else:
                raise ValueError('All images must have the same shape')
        # mean probabilities
        if image.dtype.is_floating_point:
            image = image.to(self.template.dtype)
            self.template.data[:k] += image[:, :k].sum(0)
        else:
            for soft, label in enumerate(one_hot_map):
                label = py.make_list(label)
                if len(label) == 1:
                    self.template.data[soft] += (image == label).sum(0)[0]
                else:
                    self.template.data[soft] += \
                        utils.isin(image, label).sum(0)[0]
        n += image.shape[0]
    if n == 0:
        # dividing by zero would silently produce a NaN template
        raise ValueError('At least one image is required')
    tpl = self.template.data
    tpl /= n
    # shift probabilities away from zero before taking the log
    tpl[:k] += 1e-5
    tpl[:k] /= 1 + 1e-5
    # probability mass not assigned to the first k classes
    norm = tpl[:k].sum(0).neg_().add_(1)
    if not self.implicit:
        tpl[-1] = norm
        tpl.clamp_min_(1e-5).log_()
    else:
        tpl.clamp_min_(1e-5).log_()
        tpl -= norm.clamp_min_(1e-5).log_()
def inpaint(*inputs, missing='nan', output=None, device=None, verbose=1,
            max_iter_rls=10, max_iter_cg=32, tol_rls=1e-5, tol_cg=1e-5):
    """Inpaint missing values by minimizing Joint Total Variation.

    Parameters
    ----------
    *inputs : str or tensor or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`,
        where the first element contains the volume data and the second
        contains the orientation matrix.
    missing : 'nan' or scalar or callable, default='nan'
        Mask of the missing data. If a scalar, all voxels with that value
        are considered missing. If a function, it should return the mask
        of missing values when applied to the multi-channel data.
        Non-finite values are always considered missing.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the inpainted data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.inpaint{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    verbose : int, default=1
    device : torch.device, optional
    max_iter_rls : int, default=10
    max_iter_cg : int, default=32
    tol_rls : float, default=1e-5
    tol_cg : float, default=1e-5

    Returns
    -------
    *output : str or (tensor, tensor) or tensor
        If the input is a path, the output path is returned.
        Else, the inpainted data and orientation matrix are returned
        (only the data if no orientation matrix is known).
        A single element is returned (not a list) for a single input.
    """
    # Preprocess
    dirs = []
    bases = []
    exts = []
    fnames = []
    nchannels = []
    dat = []
    aff = None
    for i, inp in enumerate(inputs):
        is_file = isinstance(inp, str)
        if is_file:
            fname = inp
            dirname, base, ext = py.fileparts(fname)
            fnames.append(inp)
            dirs.append(dirname)
            bases.append(base)
            exts.append(ext)
            f = io.volumes.map(fname)
            if aff is None:
                aff = f.affine
            f = ensure_4d(f)
            dat.append(f.fdata(device=device))
        else:
            fnames.append(None)
            dirs.append('')
            bases.append(f'{i+1}')
            exts.append('.nii.gz')
            if isinstance(inp, (list, tuple)):
                # only the first known affine is used for all inputs
                if aff is None:
                    dat1, aff = inp
                else:
                    dat1, _ = inp
            else:
                dat1 = inp
            dat.append(torch.as_tensor(dat1, device=device))
            del dat1
        # channels are stored last; remember the split for output
        nchannels.append(dat[-1].shape[-1])
    dat = utils.to(*dat, dtype=torch.float, device=utils.max_device(dat))
    if not torch.is_tensor(dat):
        dat = torch.cat(dat, dim=-1)
    dat = utils.movedim(dat, -1, 0)  # (channels, *spatial)

    # Set missing data
    # isinstance check: `missing != 'nan'` is elementwise (and ambiguous
    # in a boolean context) when `missing` is an array.
    if not (isinstance(missing, str) and missing == 'nan'):
        if not callable(missing):
            missingval = utils.make_vector(missing, dtype=dat.dtype,
                                           device=dat.device)
            missing = lambda x: utils.isin(x, missingval)
        dat[missing(dat)] = nan
    dat[~torch.isfinite(dat)] = nan

    # Do it
    if aff is not None:
        vx = spatial.voxel_size(aff)
    else:
        vx = 1
    dat = do_inpaint(dat, voxel_size=vx, verbose=verbose,
                     max_iter_rls=max_iter_rls, tol_rls=tol_rls,
                     max_iter_cg=max_iter_cg, tol_cg=tol_cg)

    # Postprocess
    dat = utils.movedim(dat, 0, -1)
    dat = dat.split(nchannels, dim=-1)
    output = py.make_list(output, len(dat))
    for i in range(len(dat)):
        if fnames[i] and not output[i]:
            output[i] = '{dir}{sep}{base}.inpaint{ext}'
        if output[i]:
            if fnames[i]:
                output[i] = output[i].format(dir=dirs[i] or '.',
                                             base=bases[i],
                                             ext=exts[i],
                                             sep=os.path.sep)
                io.volumes.save(dat[i], output[i], like=fnames[i],
                                affine=aff)
            else:
                output[i] = output[i].format(sep=os.path.sep)
                io.volumes.save(dat[i], output[i], affine=aff)

    dat = [output[i] if fnames[i] else
           (dat[i], aff) if aff is not None
           else dat[i] for i in range(len(dat))]
    if len(dat) == 1:
        dat = dat[0]
    return dat
def fdata(self, dtype=None, device=None, rand=False, missing=None,
          cache=False, copy=False, **kwargs):
    """Get scaled floating-point data.

    Note that if a mask is registered in the object, all voxels outside
    of this mask will be set to NaNs. Non-finite values are also set
    to NaNs. The stored volume itself is never modified.

    Parameters
    ----------
    dtype : torch.dtype, default=`self.dtype`
    device : torch.device, default=`self.device`
    rand : bool, default=False
        Add random noise if raw data is integer
    missing : scalar or sequence, optional
        Values that should be considered missing data.
        All of these values will be transformed to NaNs.
    cache : bool, default=False
        Cache the data in memory so that it does not need to be
        loaded again next time
    copy : bool, default=False
        Ensure that a copy of the original data is performed.
    **kwargs
        Unused.

    Returns
    -------
    dat : torch.tensor[dtype]
    """
    dtype = dtype or self.dtype
    device = device or self.device
    backend = dict(dtype=dtype, device=device)
    if missing is not None:
        missing = py.ensure_list(missing)
    # these options modify the data in place -> need a private copy
    do_copy = copy or rand or (missing is not None)

    if not cache or self._fdata is None:
        if isinstance(self.volume, io.MappedArray):
            _fdata = self.volume.fdata(**backend)
        else:
            _fdata = self.volume.to(**backend, copy=do_copy)
            # any requested copy has been made already
            do_copy = False

        # voxels to set to NaN: non-finite, missing, or outside the mask
        mask = torch.isfinite(_fdata).bitwise_not_()
        if missing:
            mask.bitwise_or_(utils.isin(_fdata, missing))
        if self._mask is not None:
            if isinstance(self._mask, io.MappedArray):
                _mask = self._mask.data(dtype=torch.bool, device=mask.device)
            else:
                _mask = self._mask.to(dtype=torch.bool, device=mask.device,
                                      copy=True)
            _mask.bitwise_not_()
            mask.bitwise_or_(_mask)

        disk_dtype = self.volume.dtype
        if isinstance(disk_dtype, (list, tuple)):
            disk_dtype = disk_dtype[0]
        if rand and not dtype_info(disk_dtype).is_floating_point:
            # uniform noise in [0, slope)
            slope = getattr(self.volume, 'slope', None) or 1
            _fdata.add_(torch.rand_like(_fdata).mul_(slope))

        if mask.any():
            if _fdata is self.volume:
                # `.to()` returned the stored tensor itself: copy before
                # writing NaNs so that `self.volume` is left untouched.
                _fdata = _fdata.clone()
            _fdata[mask] = float('nan')

        if cache:
            self._fdata = _fdata
            # the cached tensor must not be handed out unless asked
            do_copy = copy
    else:
        _fdata = self._fdata

    return _fdata.to(**backend, copy=do_copy)
def forward(self, predicted, reference, **overload):
    """
    Parameters
    ----------
    predicted : (batch, nb_class[-1], *spatial) tensor
        Predicted classes.
    reference : (batch, nb_class[-1]|1, *spatial) tensor
        Reference classes (or their expectation).
            * If `reference` has a floating point data type (`half`,
              `float`, `double`) it is assumed to hold one-hot or soft
              labels, and its channel dimension should be `nb_class` or
              `nb_class - 1`.
            * If `reference` has an integer or boolean data type, it is
              assumed to hold hard labels and its channel dimension
              should be 1. Eventually, `one_hot_map` is used to map
              one-hot labels to hard labels.
    overload : dict
        All parameters defined at build time can be overridden
        at call time (e.g. `log`, `implicit`, `weighted`, `one_hot_map`).

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    Raises
    ------
    ValueError
        If the number of soft reference classes does not match, or if a
        hard-label reference has more than one channel.
    """
    log = overload.get('log', self.log)
    implicit = overload.get('implicit', self.implicit)
    weighted = overload.get('weighted', self.weighted)

    predicted = torch.as_tensor(predicted)
    reference = torch.as_tensor(reference, device=predicted.device)
    backend = dict(dtype=predicted.dtype, device=predicted.device)

    # if only one predicted class -> must be implicit
    implicit = implicit or (predicted.shape[1] == 1)

    # take softmax if needed
    predicted = get_prob_explicit(predicted, log=log, implicit=implicit)

    nb_classes = predicted.shape[1]
    spatial_dims = list(range(2, predicted.dim()))

    # prepare weights
    if not torch.is_tensor(weighted) and not weighted:
        weighted = False
    if not isinstance(weighted, bool):
        # explicit class weights, shape (1, nb_classes)
        weighted = make_vector(weighted, nb_classes, **backend)[None]

    # preprocess reference
    if reference.dtype in (torch.half, torch.float, torch.double):
        # one-hot labels
        reference = reference.to(predicted.dtype)
        implicit_ref = reference.shape[1] == nb_classes - 1
        reference = get_prob_explicit(reference, implicit=implicit_ref)
        if reference.shape[1] != nb_classes:
            raise ValueError('Number of classes not consistent. '
                             'Expected {} or {} but got {}.'.format(
                              nb_classes, nb_classes - 1, reference.shape[1]))

        inter = nansum(predicted * reference, dim=spatial_dims)
        union = nansum(predicted + reference, dim=spatial_dims)
        loss = -2 * inter / union
        if weighted is not False:
            if weighted is True:
                # per-batch class frequencies, normalised over classes
                weights = nansum(reference, dim=spatial_dims)
                weights = weights / weights.sum(dim=1, keepdim=True)
            else:
                weights = weighted
            loss = loss * weights

    else:
        # hard labels
        if reference.shape[1] != 1:
            raise ValueError('Hard label maps cannot be multi-channel.')
        one_hot_map = overload.get('one_hot_map', self.one_hot_map)
        one_hot_map = get_one_hot_map(one_hot_map, nb_classes)

        loss = []
        weights = []
        for soft, hard in enumerate(one_hot_map):
            pred1 = predicted[:, None, soft, ...]  # (B, 1, *spatial)

            if hard is None:
                # implicit class
                all_labels = filter(lambda x: x is not None, one_hot_map)
                ref1 = ~isin(reference, flatten(list(all_labels)))
            else:
                ref1 = isin(reference, hard)

            inter = math.sum(pred1 * ref1, dim=spatial_dims)
            union = math.sum(pred1 + ref1, dim=spatial_dims)
            loss1 = -2 * inter / union
            if weighted is not False:
                if weighted is True:
                    # per-batch voxel count of this class, (B, 1), to
                    # match the per-batch weights of the soft branch
                    weight1 = ref1.sum(dim=spatial_dims)
                else:
                    # `weighted` is (1, nb_classes): index the class axis
                    weight1 = weighted[0, soft]
                loss1 = loss1 * weight1
                weights.append(weight1)
            loss.append(loss1)

        loss = torch.cat(loss, dim=1)
        if weighted is True:
            loss = loss / sum(weights)

    loss += 1
    return super().forward(loss, **overload)
def forward(self, predicted, reference, **overload):
    """
    Parameters
    ----------
    predicted : (nb_batch, nb_class_pred[-1], *spatial) tensor
        (Log)-prior probabilities
    reference : (nb_batch, nb_class_ref[-1]|1, *spatial) tensor
        Observed classes (or their expectation).
            * If `reference` has a floating point data type (`half`,
              `float`, `double`) it is assumed to hold one-hot or soft
              labels (probabilities, not log-probabilities), and its
              channel dimension should be `nb_class` or `nb_class - 1`.
            * If `reference` has an integer or boolean data type, it is
              assumed to hold hard labels and its channel dimension
              should be 1. Eventually, `one_hot_map` is used to map
              one-hot labels to hard labels.
    overload : dict
        All parameters defined at build time can be overridden
        at call time.

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    Raises
    ------
    ValueError
        If a hard-label reference has more than one channel.
    """
    log = overload.get('log', self.log)
    implicit = overload.get('implicit', self.implicit)
    confusion = overload.get('confusion', self.confusion)
    # `implicit` may be set separately for predictions and reference
    implicit_pred, implicit_ref = make_list(implicit, 2)

    predicted = torch.as_tensor(predicted)
    reference = torch.as_tensor(reference, device=predicted.device)
    backend = dict(dtype=predicted.dtype, device=predicted.device)

    predicted = get_prob_explicit(predicted, log=log, implicit=implicit_pred)
    nb_classes_pred = predicted.shape[1]
    dim = predicted.dim() - 2

    if reference.dtype in (torch.half, torch.float, torch.double):
        # soft labels
        reference = reference.to(predicted.dtype)
        # `log` describes `predicted` only; the reference holds
        # probabilities and must not be re-normalised as log-values.
        reference = get_prob_explicit(reference, implicit=implicit_ref)
        nb_classes_ref = reference.shape[1]

        confusion = get_log_confusion(confusion, nb_classes_pred,
                                      nb_classes_ref, dim, **backend)
        loss = (predicted[:, :, None] * confusion).sum(dim=1)
        loss = (loss * reference).sum(dim=1)  # (B, *spatial)
    else:
        # hard labels
        if reference.shape[1] != 1:
            raise ValueError('Hard label maps cannot be multi-channel.')
        # Drop the singleton channel -> (B, *spatial) so that it lines up
        # with `predicted[:, soft]` below. Adding an axis instead would
        # broadcast into a (B, 1, B, *spatial) cross-batch product.
        reference = reference[:, 0]
        nb_classes_ref = nb_classes_pred
        one_hot_map = overload.get('one_hot_map', self.one_hot_map)
        one_hot_map = get_one_hot_map(one_hot_map, nb_classes_ref)

        confusion = get_log_confusion(confusion, nb_classes_pred,
                                      nb_classes_ref, dim, **backend)
        predicted = (predicted[:, :, None] * confusion).sum(dim=1)
        loss = 0
        for soft, hard in enumerate(one_hot_map):
            if hard is None:
                # implicit class
                all_labels = list(
                    filter(lambda x: x is not None, one_hot_map))
                obs1 = ~isin(reference, flatten(all_labels))
            else:
                obs1 = isin(reference, hard)
            loss += predicted[:, soft] * obs1

    # negate
    loss = -loss

    # reduction
    return super().forward(loss, **overload)
def forward(self, prior, obs, **overload):
    """
    Parameters
    ----------
    prior : (nb_batch, nb_class[-1], *spatial) tensor
        (Log)-prior probabilities
    obs : (nb_batch, nb_class[-1]|1, *spatial) tensor
        Observed classes (or their expectation).
            * If `obs` has a floating point data type (`half`,
              `float`, `double`) it is assumed to hold one-hot or soft
              labels, and its channel dimension should be `nb_class` or
              `nb_class - 1`.
            * If `obs` has an integer or boolean data type, it is
              assumed to hold hard labels and its channel dimension
              should be 1. Eventually, `one_hot_map` is used to map
              one-hot labels to hard labels.
    overload : dict
        All parameters defined at build time can be overridden
        at call time.

    Returns
    -------
    loss : scalar or tensor
        The output shape depends on the type of reduction used.
        If 'mean' or 'sum', this function returns a scalar.

    Raises
    ------
    ValueError
        If the number of soft classes does not match, or if a hard-label
        map has more than one channel.
    """
    log = overload.get('log', self.log)
    implicit = overload.get('implicit', self.implicit)

    prior = torch.as_tensor(prior)
    obs = torch.as_tensor(obs, device=prior.device)

    # take log if needed
    logprior = get_logprob_explicit(prior, log=log, implicit=implicit)
    nb_classes = logprior.shape[1]

    if obs.dtype in (torch.half, torch.float, torch.double):
        # soft labels
        obs = obs.to(prior.dtype)
        obs = get_prob_explicit(obs, implicit=obs.shape[1] == nb_classes - 1)
        if obs.shape[1] != nb_classes:
            raise ValueError('Number of classes not consistent. '
                             'Expected {} or {} but got {}.'.format(
                              nb_classes, nb_classes - 1, obs.shape[1]))
        loss = logprior * obs
    else:
        # hard labels
        if obs.shape[1] != 1:
            raise ValueError('Hard label maps cannot be multi-channel.')
        # Drop the singleton channel -> (B, *spatial) to match
        # `logprior[:, soft]`. Adding an axis instead would broadcast
        # into (B, 1, B, *spatial) and fail for batch sizes > 1.
        obs = obs[:, 0]
        one_hot_map = overload.get('one_hot_map', self.one_hot_map)
        one_hot_map = get_one_hot_map(one_hot_map, nb_classes)
        # zeros: channels not covered by the map contribute nothing
        loss = torch.zeros_like(logprior)
        for soft, hard in enumerate(one_hot_map):
            if hard is None:
                # implicit class
                all_labels = list(
                    filter(lambda x: x is not None, one_hot_map))
                obs1 = ~isin(obs, flatten(all_labels))
            else:
                obs1 = isin(obs, hard)
            loss[:, soft] = logprior[:, soft] * obs1

    # negate
    loss = -loss

    # reduction
    return super().forward(loss, **overload)