def test_transpose_axes(self):
    """Check that augment_transpose_axes(axes=(1, 0)) swaps axes 1 and 2 about half the time.

    The helper is expected to map ``axes=(1, 0)`` onto array axes 1 and 2 of
    ``self.data_3D``. A random permutation of two axes has only two outcomes,
    so every call must return either:

    * the input unchanged, or
    * the input with axes 1 and 2 swapped.

    The segmentation must always receive the same permutation as the data.
    """
    n_iter = 1000
    n_swapped = 0
    # Loop-invariant reference results, computed once.
    swapped_data = np.swapaxes(self.data_3D, 1, 2)
    swapped_seg = np.swapaxes(self.seg_3D, 1, 2)
    for _ in range(n_iter):
        data_out, seg_out = augment_transpose_axes(self.data_3D, self.seg_3D, axes=(1, 0))
        if np.array_equal(data_out, swapped_data):
            n_swapped += 1
            # The segmentation must follow exactly the same permutation as the data.
            self.assertTrue(np.array_equal(seg_out, swapped_seg))
        else:
            # With only two axes to permute, the only other outcome is the identity.
            self.assertTrue(np.array_equal(data_out, self.data_3D))
            self.assertTrue(np.array_equal(seg_out, self.seg_3D))
    # n_swapped ~ Binomial(n_iter, 0.5); its std is sqrt(n_iter) / 2 (~15.8 for 1000).
    # A fixed delta of 10 is below one std and fails about half of all unseeded
    # runs, so allow four standard deviations instead.
    self.assertAlmostEqual(n_swapped, n_iter / 2., delta=4 * np.sqrt(n_iter) / 2.)
def __call__(self, **data_dict):
    """Apply one call of augment_transpose_axes to the whole data batch and its labels.

    Behavior:

    * Reads the arrays stored under ``self.data_key`` and ``self.label_key``.
    * Passes both to ``augment_transpose_axes`` in a single call, together
      with ``self.transpose_any_of_these``.
    * Stores the first result back under ``self.data_key``.
    * Stores the second result under ``self.label_key``, but only when a
      label array was present.
    * Returns the same, mutated ``data_dict``.

    NOTE(review): because the entire batch goes through one call, every sample
    presumably receives the same permutation. Confirm against
    augment_transpose_axes, whose axis convention may expect a single sample.
    """
    batch = data_dict.get(self.data_key)
    labels = data_dict.get(self.label_key)

    transposed = augment_transpose_axes(batch, labels, self.transpose_any_of_these)

    data_dict[self.data_key] = transposed[0]
    if labels is None:
        # No label array was supplied, so leave the label entry untouched.
        return data_dict
    data_dict[self.label_key] = transposed[1]
    return data_dict
def __call__(self, **data_dict):
    """Randomly transpose axes of each sample in the batch independently.

    For every index of the data stored under ``self.data_key``:

    * One ``np.random.uniform()`` value is drawn.
    * When it is below ``self.p_per_sample``, that sample (and the matching
      label sample, if labels exist under ``self.label_key``) is passed to
      ``augment_transpose_axes`` with ``self.transpose_any_of_these``.
    * The result is written back at the same index.

    The data/label containers are modified in place, then stored back into
    ``data_dict``, which is returned.

    NOTE(review): writing the transposed sample back into a numpy batch
    presumably requires the permuted axes to have equal sizes. Confirm the
    patch-size requirement with the transform's constructor.
    """
    images = data_dict.get(self.data_key)
    labels = data_dict.get(self.label_key)
    has_labels = labels is not None

    for idx in range(len(images)):
        # One uniform draw per sample decides whether that sample is augmented.
        if not (np.random.uniform() < self.p_per_sample):
            continue
        label_sample = labels[idx] if has_labels else None
        transposed = augment_transpose_axes(images[idx], label_sample, self.transpose_any_of_these)
        # Write the transposed sample back at the same index of the container.
        images[idx] = transposed[0]
        if has_labels:
            labels[idx] = transposed[1]

    data_dict[self.data_key] = images
    if has_labels:
        data_dict[self.label_key] = labels
    return data_dict