def test_rotation(self):
        data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                  axes=[0, 1])

        for i in range(self.data_3D.shape[1]):
            self.assertTrue(np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1)))
            self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))
 def test_randomness_rotation_axis(self):
     tmp = 0
     for j in range(100):
         data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=self.num_rot,
                                                   axes=[0, 1, 2])
         if np.array_equal(self.data_3D[:, 0, :, :], np.flip(data_rotated[:, :, 0, :], axis=1)):
             tmp += 1
     self.assertAlmostEqual(tmp, 33, places=2)
 def test_randomness_rotation_number(self):
     tmp = 0
     num_rot = [1, 3]
     n_iter = 1000
     for j in range(n_iter):
         data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                                   axes=[0, 1])
         normal_rotated = np.array_equal(self.data_3D[:, 0, :, :], data_rotated[:, :, - 1, :])
         if normal_rotated:
             tmp += 1
     self.assertAlmostEqual(tmp, n_iter / 2., delta=20)
 def test_rotation_list(self):
     num_rot = [1, 3]
     data_rotated, seg_rotated = augment_rot90(np.copy(self.data_3D), np.copy(self.seg_3D), num_rot=num_rot,
                                               axes=[0, 1])
     tmp = 0
     for i in range(self.data_3D.shape[1]):
         # check for normal and inverse rotations
         normal_rotated = np.array_equal(self.data_3D[:, i, :, :], data_rotated[:, :, -i-1, :])
         inverse_rotated = np.array_equal(self.data_3D[:, i, :, :], np.flip(data_rotated[:, :, i, :], axis=1))
         if normal_rotated:
             tmp += 1
         self.assertTrue(normal_rotated or inverse_rotated)
         self.assertTrue(np.array_equal(self.seg_3D[:, i, :, :], seg_rotated[:, :, -i - 1, :]) or
                         np.array_equal(self.seg_3D[:, i, :, :], np.flip(seg_rotated[:, :, i, :], axis=1)))
    def test_rotation_checkerboard(self):
        data_2d_checkerboard = np.zeros((1, 2, 2))
        data_2d_checkerboard[0, 0, 0] = 1
        data_2d_checkerboard[0, 1, 1] = 1

        data_rotated_list = []
        n_iter = 1000
        for i in range(n_iter):
            d_r, _ = augment_rot90(np.copy(data_2d_checkerboard), None, num_rot=[4,1], axes=[0, 1])
            data_rotated_list.append(d_r)

        data_rotated_np = np.array(data_rotated_list)
        sum_data_list = np.sum(data_rotated_np, axis=0)
        a = np.unique(sum_data_list)
        self.assertAlmostEqual(a[0], n_iter/2, delta=20)
        self.assertTrue(len(a) == 2)
Exemple #6
0
    def __call__(self, **data_dict):
        data = data_dict.get(self.data_key)
        seg = data_dict.get(self.label_key)

        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                d = data[b]
                if seg is not None:
                    s = seg[b]
                else:
                    s = None
                d, s = augment_rot90(d, s, self.num_rot, self.axes)
                data[b] = d
                if s is not None:
                    seg[b] = s

        data_dict[self.data_key] = data
        if seg is not None:
            data_dict[self.label_key] = seg
        return data_dict
Exemple #7
0
    def __init__(self,
                 mode='train',
                 data_path='/data/ESMH/cropped_patches/na_pd_2d/fold_0',
                 label_path='/data/ESMH/subtypes.json',
                 dim='2d',
                 use_phases=None,
                 num_phases=3):
        self.mode = mode
        self.data_path = data_path
        self.label_path = label_path
        self.dim = dim
        self.use_phases = use_phases
        self.num_phases = num_phases

        files = os.listdir(os.path.join(data_path, mode, 'image'))
        files = sorted(files)
        self.case_list = [case for case in files if case[-8] == '0']

        with open(label_path) as json_file:
            self.subtype_list = json.load(json_file)

    def __len__(self):
        return len(self.case_list)

    def __getitem__(self, idx):
        case = self.case_list[idx]
        label = self.subtype_list[case[:-4]]

        case_n = os.path.join(self.data_path, self.mode, 'image', case)
        case_a = os.path.join(self.data_path, self.mode, 'image',
                              'case_0' + '1' + case[7:])
        case_p = os.path.join(self.data_path, self.mode, 'image',
                              'case_0' + '2' + case[7:])
        case_d = os.path.join(self.data_path, self.mode, 'image',
                              'case_0' + '3' + case[7:])

        seg_path_n = os.path.join(self.data_path, self.mode, 'seg', case)
        seg_path_a = os.path.join(self.data_path, self.mode, 'seg',
                                  'case_0' + '1' + case[7:])
        seg_path_p = os.path.join(self.data_path, self.mode, 'seg',
                                  'case_0' + '2' + case[7:])
        seg_path_d = os.path.join(self.data_path, self.mode, 'seg',
                                  'case_0' + '3' + case[7:])

        img_list = []
        phase_list = []

        if os.path.isfile(case_n):
            img_n = np.load(case_n)
            seg_n = np.load(seg_path_n)
            img_n = torch.from_numpy(img_n).to(torch.float32)
            img_list.append(img_n)
            phase_list.append(0)

        if os.path.isfile(case_a):
            img_a = np.load(case_a)
            seg_a = np.load(seg_path_a)
            img_a = torch.from_numpy(img_a).to(torch.float32)
            img_list.append(img_a)
            phase_list.append(1)
        else:
            img_a = torch.zeros_like(img_n)

        if os.path.isfile(case_p):
            img_p = np.load(case_p)
            seg_p = np.load(seg_path_p)
            img_p = torch.from_numpy(img_p).to(torch.float32)
            img_list.append(img_p)
            phase_list.append(2)
        else:
            img_p = torch.zeros_like(img_n)

        if os.path.isfile(case_d):
            img_d = np.load(case_d)
            seg_d = np.load(seg_path_d)
            img_d = torch.from_numpy(img_d).to(torch.float32)
            img_list.append(img_d)
            phase_list.append(3)
        else:
            img_d = torch.zeros_like(img_n)

        if self.dim == '2d':
            target_size = 224

            for i, img in enumerate(img_list):
                img_list[i] = F.interpolate(img.unsqueeze(0),
                                            size=target_size,
                                            mode='bilinear',
                                            align_corners=False).squeeze(0)

        elif self.dim == '3d':
            target_size = (16, 64, 64)

            img_n = torch.from_numpy(img_n).to(torch.float32)
            img_p = torch.from_numpy(img_p).to(torch.float32)
            img_d = torch.from_numpy(img_d).to(torch.float32)

            img_n = F.interpolate(img_n.unsqueeze(0).unsqueeze(0),
                                  size=target_size,
                                  mode='trilinear',
                                  align_corners=False).squeeze(0)
            img_p = F.interpolate(img_p.unsqueeze(0).unsqueeze(0),
                                  size=target_size,
                                  mode='trilinear',
                                  align_corners=False).squeeze(0)
            img_d = F.interpolate(img_d.unsqueeze(0).unsqueeze(0),
                                  size=target_size,
                                  mode='trilinear',
                                  align_corners=False).squeeze(0)

            image = torch.cat((img_n, img_p, img_d))
        else:
            print('dim error')

        if len(img_list) > self.num_phases:
            if self.use_phases is None:
                use_phases = [0] + sorted(
                    random.sample(list(range(1, len(img_list))),
                                  self.num_phases - 1))
                # use_phases = sorted(random.sample(list(range(0, len(img_list))), self.num_phases))
                temp_list = [
                    img for i, img in enumerate(img_list) if i in use_phases
                ]
            else:
                temp_list = [
                    img for i, img in enumerate(img_list)
                    if phase_list[i] in self.use_phases
                ]
                if len(temp_list) < self.num_phases:
                    use_phases = [0] + sorted(
                        random.sample(list(range(1, len(img_list))),
                                      self.num_phases - 1))
                    # use_phases = sorted(random.sample(list(range(0, len(img_list))), self.num_phases))
                    temp_list = [
                        img for i, img in enumerate(img_list)
                        if i in use_phases
                    ]
            image = torch.cat(temp_list)
        else:
            image = torch.cat(img_list)

        if self.mode != 'train':
            return image, label

        image = image.numpy()

        image, _ = augment_mirroring(image)

        image, _ = augment_rot90(image,
                                 sample_seg=None,
                                 num_rot=(0, 1, 2, 3),
                                 axes=(0, 1))

        if np.random.uniform() < 0.1:  # 0.1
            image = augment_gaussian_noise(image)

        if np.random.uniform() < 0.2:  # 0.2
            image = augment_gaussian_blur(image, (0.5, 1.), p_per_channel=0.5)

        image = torch.from_numpy(image.copy()).to(torch.float32)

        return image, label