def get_trained_irs(data_list, output_path, cutoffp=(1, 20), landmarkp=None):
    """Train one IntensityRangeStandardization model per MR modality and pickle each.

    For every data class ('HGG', 'LGG', 'TEST'), the preprocessed volumes are
    read from ``output_path + <class> + '_orig.dat'``. Each is a read-only
    float32 memmap of shape ``(total, mods, SHAPE[0], SHAPE[1], SHAPE[2])``.

    For each modality (Flair, T1, T1c, T2), the voxels > 0 of each class are
    pooled into one array per class. A single model is then trained on those
    arrays together. The models are pickled to ``Flair_irs.pkl``,
    ``T1_irs.pkl``, ``T1c_irs.pkl`` and ``T2_irs.pkl`` under ``output_path``.

    Args:
        data_list: Indexed with ``DATATYPE[class]``. Its length divided by
            ``MODS[class]`` gives the number of subjects of that class.
        output_path: Prefix for the input memmaps and the output pickles. It
            is joined by plain string concatenation, so it should end with a
            path separator.
        cutoffp: Lower/upper cut-off percentiles passed to the model.
        landmarkp: Landmark percentiles passed to the model. ``None`` (the
            default) selects [2, 3, 4, 5, 6, 8, 10, 12, 14, 15, 16, 17, 18, 19].
    """
    # Original note on the defaults: cutoffp = (1, 99),
    # landmarkp = [10, 20, 30, 40, 50, 60, 70, 90]
    if landmarkp is None:
        # Fresh list per call instead of a shared mutable default argument.
        landmarkp = [2, 3, 4, 5, 6, 8, 10, 12, 14, 15, 16, 17, 18, 19]

    # Open every class's preprocessed volumes (read-only memmaps).
    volumes = []
    for data_class in ['HGG', 'LGG', 'TEST']:
        mods = MODS[data_class]
        total = len(data_list[DATATYPE[data_class]]) // mods
        if mods == 5:
            # OT (ground-truth) data needs no preprocessing, so it is kept
            # separately and excluded from the stored volumes.
            mods = 4
        fp = np.memmap(output_path + data_class + '_orig.dat', dtype=np.float32, mode='r',
                       shape=(total, mods, SHAPE[0], SHAPE[1], SHAPE[2]))
        print('\r', get_time() + ': training irs with {} images'.format(data_class))
        volumes.append(fp)

    # (modality key in MOD, output pickle file name), in the original write order.
    targets = (
        ('MR_Flair', 'Flair_irs.pkl'),
        ('MR_T1', 'T1_irs.pkl'),
        ('MR_T1c', 'T1c_irs.pkl'),
        ('MR_T2', 'T2_irs.pkl'),
    )

    trained = []
    for mod_name, filename in targets:
        mod = MOD[mod_name]
        irs = IntensityRangeStandardization(cutoffp=cutoffp, landmarkp=landmarkp)
        # One pooled foreground (voxel > 0) array per class that stores this modality.
        # Only one modality's data is held in memory at a time.
        foregrounds = []
        for fp in volumes:
            if mod not in range(fp.shape[1]):
                continue
            images = fp[:, mod, :, :, :]
            foregrounds.append(images[images > 0])
        # train() fits a new model from the images it is given; it does not
        # accumulate across calls. So all classes must go into ONE call;
        # training once per class would keep only the last class's model.
        if foregrounds:
            irs = irs.train(foregrounds)
        trained.append((filename, irs))

    # Write only after every modality trained successfully, as before.
    for filename, irs in trained:
        with open(output_path + filename, 'wb') as f:
            pickle.dump(irs, f)
def train_irs(data_list, dataset):
    """Train one IntensityRangeStandardization model on a dataset and pickle it.

    For each of the first ``len(data_list[hl[dataset]]) // 5`` entries, every
    modality ``0 .. MODCNT-1`` is loaded with ``get_img(ORIG_READ_PATH, ...)``.
    Its voxels > 0 are kept as one training image. A single model is trained
    on all of them and pickled to ``os.path.join(WRITE_PATH,
    "intensitymodel.txt")``. The file is binary pickle despite the ``.txt``
    name.

    Note: every foreground array is held in memory until training finishes.
    """
    # NOTE(review): the log text always says "LGG" regardless of `dataset`.
    logthis("Train LGG IRS Started")
    irs = IntensityRangeStandardization()
    entries = data_list[hl[dataset]]
    # NOTE(review): the count divides by 5 (4 modalities + OT?), while entries
    # are indexed as [i, mod]. Confirm this matches the layout of data_list.
    imgcnt = len(entries) // 5
    foregrounds = []
    for i in range(imgcnt):
        for mod in range(MODCNT):
            curimg = get_img(ORIG_READ_PATH, entries[i, mod])
            foregrounds.append(curimg[curimg > 0])
        print("\rIRS Train", i + 1, "/", imgcnt, end="")
    # train() refits from the images passed to it. Calling it per image (as
    # before) kept only the last image's model, so train once on everything.
    if foregrounds:
        irs = irs.train(foregrounds)
    with open(os.path.join(WRITE_PATH, "intensitymodel.txt"), 'wb') as f:
        pickle.dump(irs, f)
def train_IRS():
    """Train one IntensityRangeStandardization model on ``hl_orig.dat`` and pickle it.

    The float32 memmap ``HDD_OUTPUT_PATH + "hl_orig.dat"`` is reshaped to
    ``(-1, SHAPE[0], SHAPE[1], SHAPE[2], SHAPE[3])``. The last axis is indexed
    by modality ``0 .. MOD_CNT-1``. Each (volume, modality) pair contributes
    its voxels > 0 as one training image. A single model is trained on all of
    them and pickled to ``os.path.join(HDD_OUTPUT_PATH, "hl_irs.dat")``.

    Note: every foreground array is held in memory until training finishes.
    """
    hl_data = np.memmap(filename=HDD_OUTPUT_PATH + "hl_orig.dat", dtype=np.float32,
                        mode="r").reshape(-1, SHAPE[0], SHAPE[1], SHAPE[2], SHAPE[3])
    logthis("HL IRS training started!")
    irs = IntensityRangeStandardization()
    foregrounds = []
    for cur_cnt in range(hl_data.shape[0]):
        for mod_cnt in range(MOD_CNT):
            curmod = hl_data[cur_cnt, ..., mod_cnt]
            foregrounds.append(curmod[curmod > 0])
        print("\rHL", cur_cnt, end="")
    # train() refits from the images passed to it. Calling it per image (as
    # before) kept only the last image's model, so train once on everything.
    if foregrounds:
        irs = irs.train(foregrounds)
    with open(os.path.join(HDD_OUTPUT_PATH, "hl_irs.dat"), 'wb') as f:
        pickle.dump(irs, f)
    logthis("HL IRS training ended!")
def test_Method(self):
    """Test the normal functioning of the method.

    Scenarios, in order:

    1. Training on good images plus one bad image, then transforming the bad
       image, must not raise.
    2. ``train()`` returns the instance it was called on.
    3. ``train_transform()`` returns the same instance, and images equal to
       ``train()`` followed by per-image ``transform()``.
    4. A pickled and reloaded trained instance transforms the training images
       identically to the original.
    """
    suite = TestIntensityRangeStandardization

    # 1. A bad image in the training set must still be transformable afterwards.
    irs = IntensityRangeStandardization()
    irs.train(suite.good_trainingset + [suite.bad_image])
    irs.transform(suite.bad_image)

    # 2. train() is fluent: it returns the same instance.
    irs = IntensityRangeStandardization()
    irs_ = irs.train(suite.good_trainingset)
    self.assertEqual(irs, irs_)

    # 3. Reference results from train() + transform() ...
    irs = IntensityRangeStandardization()
    irs.train(suite.good_trainingset)
    timages = [irs.transform(i) for i in suite.good_trainingset]

    # ... must match train_transform().
    irs = IntensityRangeStandardization()
    irs_, timages_ = irs.train_transform(suite.good_trainingset)
    self.assertEqual(irs, irs_, 'instance returned by train_transform() method is not the same as the one initialized')
    for ti, ti_ in zip(timages, timages_):
        numpy.testing.assert_allclose(ti, ti_, err_msg = 'train_transform() failed to produce the same results as transform()')

    # 4. A pickle round-trip must preserve the trained model.
    irs = IntensityRangeStandardization()
    irs.train(suite.good_trainingset)
    timages = [irs.transform(i) for i in suite.good_trainingset]
    with tempfile.TemporaryFile() as f:
        pickle.dump(irs, f)
        f.seek(0, 0)
        irs_ = pickle.load(f)
    timages_ = [irs_.transform(i) for i in suite.good_trainingset]
    for ti, ti_ in zip(timages, timages_):
        numpy.testing.assert_allclose(ti, ti_, err_msg = 'pickling failed to preserve the instances model')
def test_MethodLimits(self):
    """Test the limits of the method.

    First, a model trained on the good training set must refuse to transform
    each degenerate image with the matching exception. Second, training
    directly on ten copies of a degenerate image must raise
    SingleIntensityAccumulationError.
    """
    suite = TestIntensityRangeStandardization

    # (expected exception, image passed to transform()) on a freshly trained model.
    transform_cases = (
        (InformationLossException, suite.bad_image),
        (SingleIntensityAccumulationError, suite.uniform_image),
        (SingleIntensityAccumulationError, suite.single_intensity_image),
    )
    for expected_error, probe in transform_cases:
        model = IntensityRangeStandardization()
        model.train(suite.good_trainingset)
        self.assertRaises(expected_error, model.transform, image = probe)

    # Training on a set made only of a degenerate image must fail outright.
    for degenerate in (suite.uniform_image, suite.single_intensity_image):
        model = IntensityRangeStandardization()
        self.assertRaises(SingleIntensityAccumulationError, model.train, images = [degenerate] * 10)