def LabelDice(A, B, class_labels=None):
    '''
    Dice coefficient for each class label.

    :param A: (n_batch, n_1, ..., n_k)
    :param B: (n_batch, n_1, ..., n_k)
    :param class_labels: list[n_class]
    :return: (n_batch, n_class)
    '''
    assert A.has_batch and B.has_batch
    if not class_labels:
        # Collect the labels present in either input; deduplicate before sorting.
        class_labels = sorted(set(A.unique().tolist() + B.unique().tolist()))
    # One binary indicator map per label: 1 where the voxel carries the label.
    A_labels = [1 - tp.clamp(tp.abs(A - i), 0, 1) for i in class_labels]
    B_labels = [1 - tp.clamp(tp.abs(B - i), 0, 1) for i in class_labels]
    A_maps = tp.stack(A_labels, {1})
    B_maps = tp.stack(B_labels, {1})
    return Dice(A_maps, B_maps)
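# Usage sketch (hypothetical data; assumes integer label maps and that
# tp.Tensor(..., batch_dim=0) marks the batch dimension, as in the joint
# histogram code elsewhere in this repository):
#
#     A = tp.Tensor([[0, 1, 2, 2], [0, 1, 1, 0]], batch_dim=0)
#     B = tp.Tensor([[0, 1, 2, 0], [0, 2, 1, 0]], batch_dim=0)
#     scores = LabelDice(A, B)   # (n_batch=2, n_class=3): one Dice per label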
def local_matrix(A, B, s=0, kernel="Gaussian", kernel_size=3):
    # Resolve the smoothing kernel: a named built-in or a user-supplied array.
    if isinstance(kernel, str):
        if kernel.lower() == "gaussian":
            kernel = tp.gaussian_kernel(n_dims=A.nspace, kernel_size=kernel_size).unsqueeze(0, 0)
        elif kernel.lower() == "mean":
            kernel = tp.ones(*(kernel_size,) * A.nspace).unsqueeze(0, 0) / (kernel_size ** A.nspace)
    elif hasattr(kernel, 'shape'):
        kernel_size = kernel.size(-1)

    def mean(a):
        # Local averaging by convolution; pick the conv operator matching the
        # spatial dimensionality instead of eval-ing a string.
        op = getattr(tp.nn.functional, "conv%dd" % A.nspace)
        if a.has_batch:
            x = a.unsqueeze({1})
        else:
            x = a.unsqueeze([0], {1})
        return op(x, kernel, padding=kernel_size // 2).squeeze(*((1,) if a.has_batch else (0, 0)))

    if s > 0:
        # Point estimate from image gradients.
        GA = tp.grad_image(A)
        GB = tp.grad_image(B)
        point_estim = tp.stack(tp.dot(GA, GA), tp.dot(GA, GB), tp.dot(GB, GB),
                               dim={int(A.has_batch)})
    else:
        point_estim = 0

    # Windowed second-order statistics: Var(A), Cov(A, B), Var(B).
    MA = mean(A)
    MB = mean(B)
    local_estim = tp.stack(mean(A * A) - MA ** 2, mean(A * B) - MA * MB, mean(B * B) - MB ** 2,
                           dim={int(A.has_batch)})
    return s * point_estim + local_estim
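# The result stacks the three local statistics Var(A), Cov(A, B), Var(B) along
# a new dimension, so a squared local correlation coefficient can be formed
# from it; a sketch (indices shift by one if the inputs carry a batch
# dimension, and eps is a hypothetical stabilizer):
#
#     S = local_matrix(A, B, kernel="mean", kernel_size=5)
#     lcc2 = S[1] ** 2 / (S[0] * S[2] + eps)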
def grad_image(array):
    '''
    Gradient image of array, by central differences along each spatial dimension.

    array: (n_batch, n_feature, n_1, ..., n_{n_dim})
    output: (n_batch, n_dim, n_feature, n_1, ..., n_{n_dim})
    '''
    array = tp.tensor(array)
    grad_dim = int(array.has_batch)
    output = []
    for d in range(array.ndim):
        # Skip special (batch/feature) dimensions; differentiate spatial ones only.
        if d in array.special:
            continue
        # Index tuples shifted by +1 and -1 along dimension d.
        b = (slice(None),) * d + (slice(2, None),) + (slice(None),) * (array.ndim - d - 1)
        a = (slice(None),) * d + (slice(None, -2),) + (slice(None),) * (array.ndim - d - 1)
        # Central difference, padded back to the input shape.
        output.append(tp.crop_as((array[b] - array[a]) / 2, array))
    return tp.stack(output, {grad_dim})
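# Usage sketch (hypothetical data): for a plain 2D array the output stacks the
# per-axis derivatives along a new leading dimension:
#
#     img = tp.tensor(image_data)   # (n_1, n_2)
#     g = grad_image(img)           # (2, n_1, n_2); g[0] ~ d/dx, g[1] ~ d/dy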
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
            process_group, world_size):
    input = input.contiguous()
    count = torch.empty(1,
                        dtype=running_mean.dtype,
                        device=input.device).fill_(input.numel() // input.size(1))

    # calculate mean/invstd for input.
    mean, invstd = torch.batch_norm_stats(input, eps)

    num_channels = input.shape[1]
    # C, C, 1 -> (2C + 1)
    combined = torch.cat([mean, invstd, count], dim=0)
    # world_size * (2C + 1)
    combined_list = [torch.empty_like(combined) for _ in range(world_size)]
    # Use allgather instead of allreduce since I don't trust in-place operations ..
    dist.all_gather(combined_list, combined, process_group, async_op=False)
    combined = torch.stack(combined_list, dim=0)
    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
    mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

    size = count_all.view(-1).long().sum()
    if size == 1:
        raise ValueError(
            'Expected more than 1 value per channel when training, got input size {}'
            .format(size))

    # calculate global mean & invstd
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
        input, mean_all, invstd_all, running_mean, running_var, momentum, eps,
        count_all.view(-1))

    self.save_for_backward(input, weight, mean, invstd, count_all)
    self.process_group = process_group

    # apply element-wise normalization
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
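# This is the forward pass of the autograd Function backing
# torch.nn.SyncBatchNorm; it is not usually called directly. The standard
# entry point in user code (stock PyTorch API) is:
#
#     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
#     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])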
import torch
from torch._six import inf
from typing import Union, Iterable

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm_(parameters: _tensor_or_tensors, max_norm: float,
                    norm_type: float = 2.0) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
                         for p in parameters]), norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm
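# Typical call site in a training loop (standard PyTorch usage; `model`,
# `loss`, and `optimizer` are placeholders):
#
#     loss.backward()
#     total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
#     optimizer.step()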
def forward(ctx, I1, I2, nbin=100):
    with tp.no_grad():
        if hasattr(ctx, 'JH'):
            del ctx.JH
        nbin = tp.tensor(nbin)
        data_pair = tp.stack(I1.flatten(1), I2.flatten(1), dim={1})
        nbatch, nhist, ndata = data_pair.ishape
        indices = []
        values = []
        # 4 x 4 window of B-spline support offsets around each histogram bin.
        ctx.window = (tp.image_grid(4, 4) - 1).flatten(1).transpose(0, 1)
        for shift in ctx.window:
            # [nbatch] x {nhist} x ndata
            hist_pos = data_pair * nbin
            index = tp.clamp(tp.floor(hist_pos).long() + shift, 0, nbin - 1)
            batch_idx = tp.arange(nbatch).expand_to([nbatch], {1}, ndata)
            index = tp.cat(batch_idx, index, 1)
            value = Bspline(shift.expand_to(data_pair), tp.decimal(hist_pos)).prod(1)
            indices.append(index)
            values.append(value)
        # n_batch x (1 + n_hist) x (n_data x 4 ** n_hist)
        Mindices = tp.cat(indices, -1)
        # n_batch x (n_data x 4 ** n_hist)
        Mvalues = tp.cat(values, -1)
        # (1 + n_hist) x (n_batch x n_data x 4 ** n_hist)
        indices = Mindices.transpose(0, 1).flatten(1)
        # (n_batch x n_data x 4 ** n_hist)
        values = Mvalues.flatten(0)
        if tp.Device == tp.DeviceCPU:
            creator = torch.sparse.FloatTensor
        else:
            creator = torch.cuda.sparse.FloatTensor
        # Scatter-add the B-spline weights into a (nbatch, nbin, nbin) histogram.
        collected = creator(indices, values, (nbatch, nbin, nbin)).to_dense()
        collected = tp.Tensor(collected, batch_dim=0)
        ctx.nbin = nbin
        ctx.Ishape = I1.shape
        ctx.data_pair = data_pair
        ctx.JH = collected / ndata
    return ctx.JH
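# Usage sketch: this builds a differentiable joint intensity histogram via
# cubic B-spline Parzen windows, typically as the first step of a mutual
# information loss. Assumes I1, I2 are batched images with intensities
# normalized to [0, 1] (see `data_pair * nbin` above); the Function name is
# hypothetical:
#
#     JH = JointHistogram.apply(I1, I2, 100)   # (n_batch, nbin, nbin)
#     pA, pB = JH.sum(2), JH.sum(1)            # marginal histograms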
##############################
## Author: Yuncheng Zhou
##############################

import sys
sys.path.append("../..")
# import torchplus as tp
import torch
import torchplus as tp
print(tp.__file__)
from pyctlib import scope
import copy

print(tp.stack(tp.zeros(3, 4), tp.ones(3, 4), dim={1}))

#tp.set_autodevice(False)
#tp.manual_seed(0)
#with scope("test tp, cpu"):
#    t = tp.randn([3000, 400], requires_grad=True)
#    a = t
#    LP = tp.nn.Linear(400, 400)
#    for _ in range(10): a = LP(a)
#    a.sum().backward()
#
#torch.manual_seed(0)
#with scope("test torch, cpu"):
#    t_ = torch.randn([3000, 400], requires_grad=True)
#    a_ = t_
#    LP_ = torch.nn.Linear(400, 400)
def image_grid__default__(*shape):
    # Accept both image_grid(3, 4) and image_grid((3, 4)) / image_grid([3, 4]).
    if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
        shape = shape[0]
    ret = tp.stack(tp.meshgrid(*[tp.arange(x) for x in shape]))
    return ret.channel_dimension_(0)
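# Worked example: image_grid(3, 4) returns the coordinate grid with the axis
# index stacked along a channel dimension, so ret[0][i, j] == i and
# ret[1][i, j] == j. The joint histogram code above uses
# (tp.image_grid(4, 4) - 1) in exactly this way to enumerate B-spline
# support offsets.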
def __new__(cls, instance, slice_only=False):
    if isinstance(instance, str):
        p = path(instance)
        if not p.isdir():
            if not slice_only:
                p = p @ path.Folder
        else:
            slice_only = False
        # Template bundle providing default DICOM metadata.
        dcmBundle = dcm.filereader.dcmread(path(__file__) @ path.Folder / "template.dcm")
        slice_arrays = {}
        slices = {}
        zs = {}
        readable = False
        direction_down = True
        for p in ([p] if slice_only else p):
            if not p.ext.lower() in ('dcm', 'ima'):
                continue
            try:
                image_slice = dcm.filereader.dcmread(p)
            except Exception:
                continue
            readable = True
            n_slice = int(image_slice.InstanceNumber)
            if 'SeriesNumber' in image_slice:
                n_series = int(image_slice.SeriesNumber)
            else:
                n_series = 0
            try:
                slice_array = image_slice.pixel_array
            except Exception:
                # Fall back to decompressing the file with dcmdjpeg.
                try:
                    p_dicom = (p @ path.Folder // 'dicom').mkdir() / p @ path.File
                    if not p_dicom.exists():
                        _, stderr = shell(f"dcmdjpeg {p} {p_dicom}")
                    else:
                        stderr = ''
                    if stderr:
                        raise TypeError("Unknown encoding: %s." % p)
                except Exception:
                    raise TypeError("Unknown encoding: %s." % p)
                image_slice = dcm.filereader.dcmread(p_dicom)
                try:
                    slice_array = image_slice.pixel_array
                except Exception:
                    raise TypeError("Unknown encoding: %s." % p)
            if n_series not in slices:
                slice_arrays[n_series] = {}
                slices[n_series] = {}
                zs[n_series] = {}
            slice_arrays[n_series][n_slice] = slice_array
            slices[n_series][n_slice] = image_slice
            # Find the axis perpendicular to the slice plane.
            if image_slice.ImageOrientationPatient[2] != 0:
                iz = 0
            elif image_slice.ImageOrientationPatient[5] != 0:
                iz = 1
            else:
                iz = 2
            if 'ImagePositionPatient' in image_slice:
                z = float(image_slice.ImagePositionPatient[iz])
            elif 'TablePosition' in image_slice:
                z = image_slice.TablePosition
            elif 'SliceLocation' in image_slice:
                z = float(image_slice.SliceLocation)
            else:
                z = 0.
            zs[n_series][n_slice] = z
        if not readable:
            raise TypeError("Could not create a DICOM object from %s." % p)
        # Keep the series with the most slices.
        sorted_series = sorted([(n_series, slices[n_series]) for n_series in slices],
                               key=lambda x: -len(x[1]))
        n_series = sorted_series[0][0]
        possible_series = [s[1] for s in sorted_series if s[0] == n_series]
        if len(possible_series) >= 8:
            series = possible_series[7]
        elif len(possible_series) >= 3:
            series = possible_series[2]
        else:
            series = possible_series[0]
        min_slice = 1000, None
        max_slice = 0, None
        top_slices = -float('inf'), {}
        bottom_slices = float('inf'), {}
        for n_slice in series:
            image_slice = series[n_slice]
            z = zs[n_series][n_slice]
            if n_slice < min_slice[0]:
                min_slice = n_slice, image_slice
            if n_slice > max_slice[0]:
                max_slice = n_slice, image_slice
            if z > top_slices[0]:
                top_slices = z, {n_slice: image_slice}
            if z < bottom_slices[0]:
                bottom_slices = z, {n_slice: image_slice}
            if z == top_slices[0]:
                top_slices[1][n_slice] = image_slice
            if z == bottom_slices[0]:
                bottom_slices[1][n_slice] = image_slice
        N = min(len(top_slices[1].keys()), len(bottom_slices[1].keys()))
        if N >= 8:
            i_series = 7
        elif N >= 3:
            i_series = 2
        else:
            i_series = 0
        bound1 = sorted(top_slices[1].keys())[i_series]
        bound2 = sorted(bottom_slices[1].keys())[i_series]
        if bound1 > bound2:
            zs = {k: v for k, v in zs[n_series].items() if bound2 <= k <= bound1}
            slices = {k: v for k, v in slice_arrays[n_series].items() if bound2 <= k <= bound1}
            max_slice = bound1, top_slices[1][bound1]
            min_slice = bound2, bottom_slices[1][bound2]
        elif bound1 < bound2:
            zs = {k: v for k, v in zs[n_series].items() if bound1 <= k <= bound2}
            slices = {k: v for k, v in slice_arrays[n_series].items() if bound1 <= k <= bound2}
            max_slice = bound2, bottom_slices[1][bound2]
            min_slice = bound1, top_slices[1][bound1]
        else:
            zs = {k: v for k, v in zs[n_series].items()}
            slices = {k: v for k, v in slice_arrays[n_series].items()}
            bound = sorted(series.keys())[0]
            max_slice = min_slice = bound, series[bound]
        direction_down = zs[max_slice[0]] < zs[min_slice[0]]
        typical_slice = max_slice[1] if direction_down else min_slice[1]
        # Copy scalar metadata (CamelCase DICOM keywords) into the bundle.
        for key in dir(typical_slice):
            if key == 'PixelData' or '_' in key:
                continue
            if key.capitalize() != key[0] + key[1:].lower():
                continue
            dcmBundle[key] = typical_slice[key]
        ozs = tp.Tensor(sorted(zs.values()))
        if len(set(ozs)) > 1:
            # Stack slices in z order and record the mean slice spacing.
            volume = tp.stack(orderedValue({zs[i]: slices[i] for i in slices}), -1)
            dcmBundle.SliceThickness = str(tp.abs(tp.mean(ozs[1:] - ozs[:-1])).item())
        else:
            volume = tp.stack(orderedValue(slices), -1)
        volume = volume.astype(toU(volume.dtype) if dcmBundle.PixelRepresentation
                               else toI(volume.dtype))
        dcmBundle.PixelData = volume.tobytes()
        self = super().__new__(cls, volume)
        self.bundle = dcmBundle
        self.path = path(instance)
        self.slice_only = slice_only
        self.update()
        return self
    elif hasattr(instance, 'shape'):
        if instance.ndim == 0:
            return instance
        if isinstance(instance, DCM):
            return instance
        if isinstance(instance, NII):
            # Reuse the metadata of the NIfTI object as a DICOM bundle.
            dcmBundle = nii2dcmBundle(instance)
        else:
            # Assumption (the original branch was truncated): a plain array
            # carries no metadata, so fall back to the bundled template.
            dcmBundle = dcm.filereader.dcmread(path(__file__) @ path.Folder / "template.dcm")
        self = super().__new__(cls, tp.Tensor(instance))
        self.bundle = dcmBundle
        self.path = 'Unknown'
        self.slice_only = False
        self.update()
        return self
    else:
        raise TypeError(f"Unknown input for DCM: {instance}.")
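# Usage sketch (hypothetical paths; assumes the pyctlib `path` conventions
# used above):
#
#     vol = DCM("/data/patient01/CT")                              # whole series
#     sl = DCM("/data/patient01/CT/IM0001.dcm", slice_only=True)   # one slice
#     arr = DCM(existing_array)                                    # wrap an array-like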