def __getitem__(self, index):
    """Load one case and return ``(inputs, label)`` for the data loader.

    Reading the case:
        The case is read with ``pkload`` from
        ``self.paths[index] + 'data_f32.pkl'``. It is expected to yield an
        ``(images, label)`` pair of NumPy arrays, which are wrapped with
        ``torch.from_numpy``.

    Without cropping (``self.crop`` false):
        The whole volume is passed through ``self.transforms``. The method
        returns ``((images, self.all_coords), label)``, with ``images``
        permuted ``(3, 0, 1, 2)`` so that its last axis comes first.

    With cropping (``self.crop`` true):
        Patches are cut with ``multicrop.crop3d_cpu`` around ``coords`` at
        two sizes:

        * ``self.sample_size``, with sixth argument ``1``.
        * ``self.sub_sample_size``, with sixth argument ``3``. This is
          presumably a subsampling factor; confirm against ``multicrop``.

        The method returns ``((samples, sub_samples, target), label)``, with
        both patch batches permuted ``(0, 4, 1, 2, 3)``.

        * ``target`` holds label patches of ``self.target_size`` when
          ``self.return_target`` is set; otherwise it is the crop ``coords``.
        * When ``self.for_train`` is set, ``label`` is replaced by the
          ``_zero`` placeholder defined elsewhere in this module.
    """
    path = self.paths[index]

    # Pre-converted pickles load faster than reading the NIfTI volumes.
    images, label = pkload(path + 'data_f32.pkl')
    # from_numpy shares memory with the loaded arrays instead of copying.
    images, label = torch.from_numpy(images), torch.from_numpy(label)

    if not self.crop:
        # The transforms expect an n,h,w,t,c batch, so add a batch axis,
        # transform, then drop the batch axis again.
        images, label = images.unsqueeze(0), label.unsqueeze(0)
        images, label = self.transforms([images, label])
        images, label = images.squeeze(0), label.squeeze(0)

        # Move the trailing (channel) axis to the front.
        images = images.permute(3, 0, 1, 2).contiguous()
        return (images, self.all_coords), label

    if self.for_train:
        # Draw num_patches // 2 centres from each of the two stored coordinate
        # sets. ``sample`` is defined elsewhere in this module.
        # NOTE(review): an odd ``num_patches`` yields one patch fewer,
        # assuming ``sample`` returns exactly the requested number of rows.
        fg, bg = pkload(path + self.suffix + 'coords.pkl')
        half = self.num_patches // 2
        coords = torch.cat([sample(x, half) for x in (fg, bg)])
    else:
        coords = self.all_coords

    size, sub_size = self.sample_size, self.sub_sample_size
    samples = multicrop.crop3d_cpu(
        images, coords, size, size, size, 1, False)
    sub_samples = multicrop.crop3d_cpu(
        images, coords, sub_size, sub_size, sub_size, 3, False)

    if self.return_target:
        t_size = self.target_size
        target = multicrop.crop3d_cpu(
            label, coords, t_size, t_size, t_size, 1, False)
        samples, sub_samples, target = self.transforms(
            [samples, sub_samples, target])
    else:
        samples, sub_samples = self.transforms([samples, sub_samples])
        target = coords

    if self.for_train:
        # The full-volume label is not returned during training.
        label = _zero

    # Cropped patches carry channels last; move them to axis 1.
    samples = samples.permute(0, 4, 1, 2, 3).contiguous()
    sub_samples = sub_samples.permute(0, 4, 1, 2, 3).contiguous()
    return (samples, sub_samples, target), label
"""Round-trip sanity check for ``multicrop.crop3d_cpu``.

The script tiles a random int16 volume with non-overlapping cubic patches
centred on a regular grid. It crops the patches with ``crop3d_cpu``, stitches
them back together and prints whether the result equals the input.
"""
import numpy as np
import torch

import multicrop


def _grid_centres(shape, stride):
    """Return an int16 ``(N, 3)`` tensor of grid centres to use as crop coords.

    Along each axis the centres are ``stride // 2 + k * stride``. They are
    enumerated in C order (last axis fastest) via
    ``np.meshgrid(..., indexing='ij')``.
    """
    axes = [stride // 2 + np.arange(0, size, stride) for size in shape]
    mesh = np.meshgrid(*axes, indexing='ij')
    flat = np.stack([m.reshape(-1) for m in mesh], -1)
    return torch.tensor(flat, dtype=torch.int16)


def main(shape=(10, 15, 10), stride=5, channels=3):
    """Crop a random volume into stride-sized patches and check they re-tile it.

    Args:
        shape: Spatial size of the 3-D test volume. Every entry must be a
            multiple of ``stride``.
        stride: Grid spacing. It is also used as the patch edge length, so
            the patches tile the volume exactly.
        channels: Size of the trailing channel axis.

    Raises:
        ValueError: If ``shape`` is not 3-D, or an entry is not a multiple of
            ``stride``.
    """
    if len(shape) != 3:
        raise ValueError(f"shape must be 3-D, got {shape}")
    if any(size % stride for size in shape):
        raise ValueError(f"shape {shape} must be a multiple of stride {stride}")
    gx, gy, gz = (size // stride for size in shape)  # patches per axis

    x = torch.randint(0, 5, (*shape, channels), dtype=torch.int16)
    coords = _grid_centres(shape, stride)

    # NOTE(review): exact reconstruction relies on crop3d_cpu returning a patch
    # of edge `stride` that starts at k * stride around each centre. Confirm
    # this holds for even strides.
    y = multicrop.crop3d_cpu(x, coords, stride, stride, stride, 1, False)

    # View the patches as (gx, gy, gz, p, p, p, C). Interleave the grid axes
    # with the in-patch axes, then merge each pair back into one spatial axis.
    y = (y.view(gx, gy, gz, stride, stride, stride, channels)
          .permute(0, 3, 1, 4, 2, 5, 6)
          .reshape(*shape, channels))
    print((x == y).all().item())


if __name__ == "__main__":
    main()