def get_trained_irs(
    data_list,
    output_path,
    cutoffp=(1, 20),
    landmarkp=(2, 3, 4, 5, 6, 8, 10, 12, 14, 15, 16, 17, 18, 19)
):  # Default : cutoffp = (1, 99), landmarkp = [10, 20, 30, 40, 50, 60, 70, 90]
    """Train one IntensityRangeStandardization model per MR modality and pickle each.

    Iterates over the 'HGG', 'LGG' and 'TEST' memmapped image stacks under
    ``output_path`` and feeds the foreground voxels (intensity > 0) of each
    modality into its own IRS model, then dumps the four trained models to
    ``<output_path><Modality>_irs.pkl``.

    :param data_list: mapping indexed by DATATYPE[...] whose entries' lengths
        determine how many volumes each memmap holds
    :param output_path: directory prefix where ``*_orig.dat`` inputs live and
        ``*_irs.pkl`` outputs are written (expected to end with a separator)
    :param cutoffp: percentile cut-off pair forwarded to each IRS model
    :param landmarkp: landmark percentiles forwarded to each IRS model
        (immutable tuple default — avoids the shared-mutable-default pitfall)
    """
    # One model per modality, keyed by modality name; a dict keeps the
    # train/persist steps table-driven instead of four copy-pasted branches.
    irs_models = {
        name: IntensityRangeStandardization(cutoffp=cutoffp,
                                            landmarkp=landmarkp)
        for name in ('MR_Flair', 'MR_T1', 'MR_T1c', 'MR_T2')
    }
    mod_index_to_name = {MOD[name]: name for name in irs_models}

    for data_class in ['HGG', 'LGG', 'TEST']:
        mods = MODS[data_class]
        total = len(data_list[DATATYPE[data_class]]) // mods
        if mods == 5:  # The OT (ground-truth) channel needs no preprocessing, so it is handled separately.
            mods = 4
        fp = np.memmap(output_path + data_class + '_orig.dat',
                       dtype=np.float32,
                       mode='r',
                       shape=(total, mods, SHAPE[0], SHAPE[1], SHAPE[2]))
        print('\r',
              get_time() + ': training irs with {} images'.format(data_class))
        # NOTE(review): original comment said a previously trained standard
        # intensity space is loaded and training continues — loading is not
        # visible here; training simply accumulates across data classes.
        for mod in range(mods):
            images = fp[:, mod, :, :, :]
            name = mod_index_to_name.get(mod)
            if name is not None:
                # Train only on foreground voxels (intensity > 0).
                irs_models[name] = irs_models[name].train([images[images > 0]])

    # Persist each trained model; filenames drop the 'MR_' prefix.
    for name, filename in (('MR_Flair', 'Flair_irs.pkl'),
                           ('MR_T1', 'T1_irs.pkl'),
                           ('MR_T1c', 'T1c_irs.pkl'),
                           ('MR_T2', 'T2_irs.pkl')):
        with open(output_path + filename, 'wb') as f:
            pickle.dump(irs_models[name], f)
def train_irs(data_list, dataset):
    """Train a single IRS model over every modality of one dataset split.

    Reads each image of the split from disk, feeds its foreground voxels
    (intensity > 0) into the model, and pickles the trained model to
    ``WRITE_PATH/intensitymodel.txt``.

    :param data_list: mapping indexed by hl[dataset]; each row holds the
        per-modality image identifiers for one subject
    :param dataset: key selecting the split via the ``hl`` lookup table
    """
    logthis("Train LGG IRS Started")
    irs = IntensityRangeStandardization()
    n_images = len(data_list[hl[dataset]]) // 5
    for idx in range(n_images):
        # Accumulate training statistics one modality at a time.
        for modality in range(MODCNT):
            img = get_img(ORIG_READ_PATH, data_list[hl[dataset]][idx, modality])
            irs = irs.train([img[img > 0]])  # foreground voxels only
        print("\rIRS Train", idx + 1, "/", n_images, end="")
    with open(os.path.join(WRITE_PATH, "intensitymodel.txt"), 'wb') as f:
        pickle.dump(irs, f)
def train_IRS():
    """Train an IRS model on the HL memmapped volume stack and pickle it.

    Loads ``hl_orig.dat`` as a read-only float32 memmap reshaped to
    (volumes, SHAPE[0], SHAPE[1], SHAPE[2], SHAPE[3]), trains one shared
    IntensityRangeStandardization model on the foreground voxels
    (intensity > 0) of every modality of every volume, and writes the
    trained model to ``HDD_OUTPUT_PATH/hl_irs.dat``.
    """
    hl_data = np.memmap(filename=HDD_OUTPUT_PATH + "hl_orig.dat",
                        dtype=np.float32,
                        mode="r").reshape(-1, SHAPE[0], SHAPE[1], SHAPE[2],
                                          SHAPE[3])

    logthis("HL IRS training started!")
    irs = IntensityRangeStandardization()
    n_volumes = hl_data.shape[0]
    for vol_idx in range(n_volumes):
        for mod_idx in range(MOD_CNT):
            # Last axis indexes the modality channel of this volume.
            channel = hl_data[vol_idx, ..., mod_idx]
            irs = irs.train([channel[channel > 0]])
        print("\rHL", vol_idx, end="")
    with open(os.path.join(HDD_OUTPUT_PATH, "hl_irs.dat"), 'wb') as f:
        pickle.dump(irs, f)
    logthis("HL IRS training ended!")
 def test_Method(self):
     """Test the normal functioning of the method."""
     fixtures = TestIntensityRangeStandardization

     # Training must tolerate a mix of good and bad images, and the
     # resulting model must still transform the bad image without raising.
     irs = IntensityRangeStandardization()
     irs.train(fixtures.good_trainingset + [fixtures.bad_image])
     irs.transform(fixtures.bad_image)

     # train() returns the trained instance itself (fluent API).
     irs = IntensityRangeStandardization()
     trained = irs.train(fixtures.good_trainingset)
     self.assertEqual(irs, trained)

     # train() followed by per-image transform() must match train_transform().
     irs = IntensityRangeStandardization()
     irs.train(fixtures.good_trainingset)
     transformed = [irs.transform(img) for img in fixtures.good_trainingset]

     irs = IntensityRangeStandardization()
     trained, transformed_ = irs.train_transform(fixtures.good_trainingset)

     self.assertEqual(irs, trained, 'instance returned by transform() method is not the same as the once initialized')
     for expected, actual in zip(transformed, transformed_):
         numpy.testing.assert_allclose(expected, actual, err_msg='train_transform() failed to produce the same results as transform()')

     # Round-tripping a trained model through pickle must preserve its output.
     irs = IntensityRangeStandardization()
     irs.train(fixtures.good_trainingset)
     transformed = [irs.transform(img) for img in fixtures.good_trainingset]

     with tempfile.TemporaryFile() as f:
         pickle.dump(irs, f)
         f.seek(0, 0)
         restored = pickle.load(f)

     transformed_ = [restored.transform(img) for img in fixtures.good_trainingset]

     for expected, actual in zip(transformed, transformed_):
         numpy.testing.assert_allclose(expected, actual, err_msg='pickling failed to preserve the instances model')
 def test_MethodLimits(self):
     """Test the limits of the method."""
     fixtures = TestIntensityRangeStandardization

     # A trained model must reject images it cannot map faithfully.
     transform_failures = (
         (InformationLossException, fixtures.bad_image),
         (SingleIntensityAccumulationError, fixtures.uniform_image),
         (SingleIntensityAccumulationError, fixtures.single_intensity_image),
     )
     for exc_type, img in transform_failures:
         irs = IntensityRangeStandardization()
         irs.train(fixtures.good_trainingset)
         self.assertRaises(exc_type, irs.transform, image=img)

     # Training on degenerate (single-intensity) inputs must fail outright.
     for img in (fixtures.uniform_image, fixtures.single_intensity_image):
         irs = IntensityRangeStandardization()
         self.assertRaises(SingleIntensityAccumulationError, irs.train, images=[img] * 10)