def split_train_test(self, conf):
    """Split the shots into a train list and a test list according to ``conf``.

    NOTE(review): this source contains another ``split_train_test``
    definition; if both live in the same class, only the later one takes
    effect. Consider deleting the duplicate.

    If ``conf['paths']['shot_files_test']`` is empty (or ``None``),
    ``self.shots`` is split with ``train_test_split`` using
    ``conf['training']['train_frac']`` and
    ``conf['training']['shuffle_training']``. Otherwise the train list is
    loaded from ``conf['paths']['shot_files']`` and the test list from
    ``conf['paths']['shot_files_test']``, both with
    ``conf['paths']['all_signals']``.

    In both cases the shot numbers of each side are passed through
    ``self.filter_by_number`` so that only shots from this list are kept.
    Each side is then passed to ``random_sublist`` with its share of
    ``conf['data']['use_shots']``.

    Args:
        conf: nested configuration dict; only the keys listed above are read.

    Returns:
        A ``(shots_train, shots_test)`` tuple, each being the result of
        ``random_sublist`` on the corresponding filtered list.
    """
    shot_files = conf['paths']['shot_files']
    shot_files_test = conf['paths']['shot_files_test']
    train_frac = conf['training']['train_frac']
    shuffle_training = conf['training']['shuffle_training']
    use_shots = conf['data']['use_shots']
    all_signals = conf['paths']['all_signals']

    # Split the total shot budget between train and test. The test share is
    # derived by subtraction so the two always add up to the total. Rounding
    # both products independently (round() rounds half to even) can be off by
    # one, e.g. train_frac=0.5, use_shots=3 would give 2 + 2 = 4.
    use_shots_total = int(round(use_shots))
    use_shots_train = int(round(train_frac * use_shots_total))
    use_shots_test = use_shots_total - use_shots_train

    if not shot_files_test:
        # No explicit test files: split this list randomly.
        shot_list_train, shot_list_test = train_test_split(
            self.shots, train_frac, shuffle_training)
    else:
        # Train and test lists are given explicitly.
        shot_list_train = ShotList()
        shot_list_train.load_from_shot_list_files_objects(
            shot_files, all_signals)
        shot_list_test = ShotList()
        shot_list_test.load_from_shot_list_files_objects(
            shot_files_test, all_signals)

    shot_numbers_train = [shot.number for shot in shot_list_train]
    shot_numbers_test = [shot.number for shot in shot_list_test]
    # Make sure we only use pre-filtered valid shots.
    shots_train = self.filter_by_number(shot_numbers_train)
    shots_test = self.filter_by_number(shot_numbers_test)
    return (shots_train.random_sublist(use_shots_train),
            shots_test.random_sublist(use_shots_test))
def split_train_test(self, conf):
    """Split the shots into a train list and a test list according to ``conf``.

    If ``conf['paths']['shot_files_test']`` is empty (or ``None``),
    ``self.shots`` is split with ``train_test_split`` using
    ``conf['training']['train_frac']`` and
    ``conf['training']['shuffle_training']``. This is meant for sets sampled
    from the same distribution, e.g. D3D train and test.

    Otherwise the train list is loaded from ``conf['paths']['shot_files']``
    and the test list from ``conf['paths']['shot_files_test']``, both with
    ``conf['paths']['all_signals']``. This is meant for sets drawn from
    separate distributions, e.g. train=CW and test=ILW for JET.

    In both cases the shot numbers of each side are passed through
    ``self.filter_by_number`` so that only shots from this list are kept.
    Each side is then passed to ``random_sublist`` with its share of
    ``conf['data']['use_shots']``.

    Args:
        conf: nested configuration dict; only the keys listed above are read.

    Returns:
        A ``(shots_train, shots_test)`` tuple, each being the result of
        ``random_sublist`` on the corresponding filtered list.
    """
    shot_files = conf['paths']['shot_files']
    shot_files_test = conf['paths']['shot_files_test']
    train_frac = conf['training']['train_frac']
    shuffle_training = conf['training']['shuffle_training']
    use_shots = conf['data']['use_shots']
    all_signals = conf['paths']['all_signals']

    # Split the "maximum number of shots to use" into
    # test vs. (train U validate). The test share is derived by subtraction
    # so the two always add up to the total. Rounding both products
    # independently (round() rounds half to even) can be off by one,
    # e.g. train_frac=0.5, use_shots=3 would give 2 + 2 = 4.
    use_shots_total = int(round(use_shots))
    use_shots_train = int(round(train_frac * use_shots_total))
    use_shots_test = use_shots_total - use_shots_train

    if not shot_files_test:
        # Split randomly, i.e. sample both sets from the same distribution,
        # such as D3D test and train.
        shot_list_train, shot_list_test = train_test_split(
            self.shots, train_frac, shuffle_training)
    else:
        # Train and test lists given, e.g. they are sampled from separate
        # distributions such as train=CW and test=ILW for JET.
        shot_list_train = ShotList()
        shot_list_train.load_from_shot_list_files_objects(
            shot_files, all_signals)
        shot_list_test = ShotList()
        shot_list_test.load_from_shot_list_files_objects(
            shot_files_test, all_signals)

    shot_numbers_train = [shot.number for shot in shot_list_train]
    shot_numbers_test = [shot.number for shot in shot_list_test]
    # Make sure we only use pre-filtered valid shots.
    shots_train = self.filter_by_number(shot_numbers_train)
    shots_test = self.filter_by_number(shot_numbers_test)
    return (shots_train.random_sublist(use_shots_train),
            shots_test.random_sublist(use_shots_test))
def split_direct(self, frac, do_shuffle=True):
    """Split ``self.shots`` into two new ShotList objects.

    The split itself is delegated to ``train_test_split``, with ``frac`` and
    ``do_shuffle`` passed through unchanged. Presumably ``frac`` is the
    fraction of shots placed in the first part; confirm against
    ``train_test_split``.

    Returns:
        A ``(first, second)`` tuple of ShotList objects wrapping the two
        parts returned by ``train_test_split``.
    """
    parts = train_test_split(self.shots, frac, do_shuffle)
    first_part, second_part = parts
    first_list = ShotList(first_part)
    second_list = ShotList(second_part)
    return first_list, second_list