def __getitem__(self, index):
        """Return the item at ``index`` as a whole volume or as crops.

        Loads ``<path>data_f32.pkl`` via ``pkload`` and converts both arrays
        with ``torch.from_numpy``.

        When ``self.crop`` is falsy:
          * the image and label get a leading singleton axis before
            ``self.transforms`` (the original note says the transforms
            need an "nhwtc" layout);
          * that axis is removed again;
          * the image's axis 3 is moved to the front.
        Returns ``(image, self.all_coords), label``.

        Otherwise the crop centers are chosen as follows:
          * when ``self.for_train``: ``self.num_patches // 2`` centers are
            drawn with ``sample`` from each of the two coordinate sets in
            ``<path><suffix>coords.pkl`` (presumably foreground and
            background -- TODO confirm), then concatenated;
          * otherwise: ``self.all_coords``.

        Two ``multicrop.crop3d_cpu`` crops are taken around those centers:
        one with all three size arguments set to ``self.sample_size`` and
        sixth argument 1, and one with ``self.sub_sample_size`` and sixth
        argument 3. The meaning of that argument is not visible here;
        presumably it is a scale factor -- confirm against multicrop.

        If ``self.return_target``, the label is cropped with
        ``self.target_size`` and transformed together with the two crops.
        Otherwise ``target`` is the centers tensor itself.

        When ``self.for_train`` the returned label is replaced by ``_zero``.
        Both crop tensors have axis 4 moved to axis 1 before returning
        ``(samples, sub_samples, target), label``.
        """
        item_dir = self.paths[index]

        # Pickled float32 arrays; the original author noted these load
        # faster than NIfTI files.
        imgs, lbl = pkload(item_dir + 'data_f32.pkl')
        imgs = torch.from_numpy(imgs)
        lbl = torch.from_numpy(lbl)

        if self.crop:
            if self.for_train:
                fg_coords, bg_coords = pkload(
                    item_dir + self.suffix + 'coords.pkl')
                per_set = self.num_patches // 2
                centers = torch.cat(
                    [sample(fg_coords, per_set), sample(bg_coords, per_set)])
            else:
                centers = self.all_coords

            size = self.sample_size
            samples = multicrop.crop3d_cpu(imgs, centers, size, size, size,
                                           1, False)
            sub_size = self.sub_sample_size
            sub_samples = multicrop.crop3d_cpu(imgs, centers, sub_size,
                                               sub_size, sub_size, 3, False)

            if not self.return_target:
                samples, sub_samples = self.transforms([samples, sub_samples])
                target = centers
            else:
                t_size = self.target_size
                target = multicrop.crop3d_cpu(lbl, centers, t_size, t_size,
                                              t_size, 1, False)
                samples, sub_samples, target = self.transforms(
                    [samples, sub_samples, target])

            if self.for_train:
                lbl = _zero

            samples = samples.permute(0, 4, 1, 2, 3).contiguous()
            sub_samples = sub_samples.permute(0, 4, 1, 2, 3).contiguous()
            return (samples, sub_samples, target), lbl

        # Whole-volume path: add a leading axis for the transforms (they
        # expect "nhwtc" per the original note), then strip it again.
        imgs, lbl = self.transforms([imgs.unsqueeze(0), lbl.unsqueeze(0)])
        imgs = imgs.squeeze(0).permute(3, 0, 1, 2).contiguous()
        lbl = lbl.squeeze(0)
        return (imgs, self.all_coords), lbl
import numpy as np
import torch
import multicrop

# Round-trip self-check for multicrop.crop3d_cpu:
#   1. crop non-overlapping _stride-sized cubes centered on a regular grid;
#   2. stitch them back together;
#   3. print whether the result reproduces the input exactly.
# All sizes are derived from _stride / _shape / _channels so the constants
# can be changed without editing the reassembly step.
_stride = 5
_shape = (10, 15, 10)
_channels = 3

# The cubes only tile the volume when every extent is a multiple of
# _stride; otherwise the reassembly below cannot be reshaped back to _shape.
if any(s % _stride for s in _shape):
    raise ValueError(f'every entry of _shape {_shape} must be a multiple of '
                     f'_stride={_stride}')

# Number of cubes along each spatial axis, e.g. (2, 3, 2) for the defaults.
_grid = tuple(s // _stride for s in _shape)

x = torch.randint(0, 5, (*_shape, _channels), dtype=torch.int16)

# One center per cube, half a stride into each cell. np.meshgrid with 'ij'
# indexing, flattened in C order, gives rows that walk the grid with the
# last axis varying fastest: an (n_cubes, 3) int16 tensor.
coords = torch.tensor(
    np.stack([
        v.reshape(-1) for v in np.meshgrid(
            *[_stride // 2 + np.arange(0, s, _stride) for s in _shape],
            indexing='ij')
    ], -1),
    dtype=torch.int16,
)

y = multicrop.crop3d_cpu(x, coords, _stride, _stride, _stride, 1, False)

# Reinterpret the crops as (grid_i, grid_j, grid_k, d_i, d_j, d_k, C).
# This presumes crop3d_cpu emits one contiguous cube per center, in coords
# order. Then interleave grid and in-cube axes to rebuild the full volume.
y = y.view(*_grid, _stride, _stride, _stride, _channels)
y = y.permute(0, 3, 1, 4, 2, 5, 6).reshape(*_shape, _channels)

print((x == y).all().item())