def split_train_test(self,conf):
    """Build the training and test shot lists described by ``conf``.

    Config entries read:
        conf['paths']['shot_files']            -- shot-list files for training
        conf['paths']['shot_files_test']       -- shot-list files for testing;
                                                  when empty, ``self.shots`` is
                                                  split instead
        conf['paths']['all_signals']           -- passed to the shot-list loader
        conf['training']['train_frac']         -- fraction assigned to training
        conf['training']['shuffle_training']   -- forwarded to train_test_split
        conf['data']['use_shots']              -- overall shot budget

    Returns:
        A ``(train, test)`` tuple: the results of ``random_sublist`` called on
        the filtered training and test lists with their respective budgets.
    """
    paths = conf['paths']
    training = conf['training']
    shot_files = paths['shot_files']
    shot_files_test = paths['shot_files_test']
    train_frac = training['train_frac']
    shuffle_training = training['shuffle_training']
    use_shots = conf['data']['use_shots']
    all_signals = paths['all_signals']
    # NOTE: conf['paths']['shot_list_dir'] used to be read here but was never
    # used, so the lookup has been removed.

    # Divide the overall shot budget between the two sets. Each count is
    # rounded independently, so they are not guaranteed to sum to use_shots.
    use_shots_train = int(round(train_frac*use_shots))
    use_shots_test = int(round((1-train_frac)*use_shots))

    if not shot_files_test:
        # No dedicated test files: split this object's own shots.
        shot_list_train,shot_list_test = train_test_split(
            self.shots,train_frac,shuffle_training)
    else:
        # Train and test lists are given explicitly as separate files.
        shot_list_train = ShotList()
        shot_list_train.load_from_shot_list_files_objects(shot_files,all_signals)

        shot_list_test = ShotList()
        shot_list_test.load_from_shot_list_files_objects(shot_files_test,all_signals)

    shot_numbers_train = [shot.number for shot in shot_list_train]
    shot_numbers_test = [shot.number for shot in shot_list_test]
    print(len(shot_numbers_train),len(shot_numbers_test))
    # Make sure we only use pre-filtered valid shots: select by shot number
    # from this object rather than using the freshly built lists directly.
    shots_train = self.filter_by_number(shot_numbers_train)
    shots_test = self.filter_by_number(shot_numbers_test)
    return shots_train.random_sublist(use_shots_train),shots_test.random_sublist(use_shots_test)
Exemple #2
0
 def split_train_test(self,conf):
     """Return the (train, test) shot selections configured in ``conf``.

     Reads ``shot_files``, ``shot_files_test`` and ``all_signals`` from
     ``conf['paths']``, ``train_frac`` and ``shuffle_training`` from
     ``conf['training']``, and ``use_shots`` from ``conf['data']``.

     If ``shot_files_test`` is empty, ``self.shots`` is divided with
     ``train_test_split``; otherwise the two sets are loaded from their
     respective shot-list files. In both cases the resulting shot numbers
     are passed through ``self.filter_by_number`` and the filtered lists
     are sampled with ``random_sublist``.

     Returns:
         Tuple ``(train, test)`` of the two ``random_sublist`` results.
     """
     shot_files = conf['paths']['shot_files']
     shot_files_test = conf['paths']['shot_files_test']
     train_frac = conf['training']['train_frac']
     shuffle_training = conf['training']['shuffle_training']
     use_shots = conf['data']['use_shots']
     all_signals = conf['paths']['all_signals']
     # The unused conf['paths']['shot_list_dir'] lookup was dropped: nothing
     # in this method needs it.

     # Per-set budgets; rounded separately, so their sum may differ from
     # use_shots by one.
     use_shots_train = int(round(train_frac*use_shots))
     use_shots_test = int(round((1-train_frac)*use_shots))

     if shot_files_test:
         # Train and test lists given as separate shot-list files.
         shot_list_train = ShotList()
         shot_list_train.load_from_shot_list_files_objects(shot_files,all_signals)

         shot_list_test = ShotList()
         shot_list_test.load_from_shot_list_files_objects(shot_files_test,all_signals)
     else:
         # No test files configured: split the shots held by this object.
         shot_list_train,shot_list_test = train_test_split(
             self.shots,train_frac,shuffle_training)

     shot_numbers_train = [shot.number for shot in shot_list_train]
     shot_numbers_test = [shot.number for shot in shot_list_test]
     print(len(shot_numbers_train),len(shot_numbers_test))
     # Make sure we only use pre-filtered valid shots.
     shots_train = self.filter_by_number(shot_numbers_train)
     shots_test = self.filter_by_number(shot_numbers_test)
     return shots_train.random_sublist(use_shots_train),shots_test.random_sublist(use_shots_test)
    def split_train_test(self, conf):
        """Select training and test shots according to ``conf``.

        The "maximum number of shots to use" (``conf['data']['use_shots']``)
        is divided into a test budget and a (train U validate) budget using
        ``conf['training']['train_frac']``.

        Returns:
            A 2-tuple ``(train, test)`` holding the ``random_sublist`` results
            for the filtered training and test shots.
        """
        path_conf = conf['paths']
        train_files = path_conf['shot_files']
        test_files = path_conf['shot_files_test']
        training_conf = conf['training']
        fraction = training_conf['train_frac']
        shuffle = training_conf['shuffle_training']
        max_shots = conf['data']['use_shots']
        signals = path_conf['all_signals']

        n_train = int(round(fraction * max_shots))
        n_test = int(round((1 - fraction) * max_shots))

        if len(test_files) > 0:
            # Explicit train and test lists, e.g. sampled from separate
            # distributions such as train=CW and test=ILW for JET.
            train_candidates = ShotList()
            train_candidates.load_from_shot_list_files_objects(
                train_files, signals)

            test_candidates = ShotList()
            test_candidates.load_from_shot_list_files_objects(
                test_files, signals)
        else:
            # Random split, e.g. both sets sampled from the same
            # distribution such as D3D test and train.
            train_candidates, test_candidates = train_test_split(
                self.shots, fraction, shuffle)

        train_numbers = [candidate.number for candidate in train_candidates]
        test_numbers = [candidate.number for candidate in test_candidates]

        # Only pre-filtered valid shots may be used, so look the shots up
        # by number on this object.
        valid_train = self.filter_by_number(train_numbers)
        valid_test = self.filter_by_number(test_numbers)

        train_sample = valid_train.random_sublist(n_train)
        test_sample = valid_test.random_sublist(n_test)
        return train_sample, test_sample
 def split_direct(self,frac,do_shuffle=True):
     """Split ``self.shots`` in two and wrap each part in a ``ShotList``.

     ``frac`` and ``do_shuffle`` are forwarded unchanged to
     ``train_test_split``. Returns a ``(first, second)`` tuple of ShotLists.
     """
     parts = train_test_split(self.shots,frac,do_shuffle)
     first_part,second_part = parts
     return ShotList(first_part),ShotList(second_part)
Exemple #5
0
 def split_direct(self,frac,do_shuffle=True):
     """Divide this object's shots into two ``ShotList`` instances.

     Calls ``train_test_split(self.shots, frac, do_shuffle)`` and returns
     the two resulting pieces, each wrapped in a new ``ShotList``, as a
     tuple in the order the helper produced them.
     """
     head,tail = train_test_split(self.shots,frac,do_shuffle)
     wrapped_head = ShotList(head)
     wrapped_tail = ShotList(tail)
     return wrapped_head,wrapped_tail