Example #1
    def save_shot(self, shot, P_thresh_opt=0, extra_filename=''):
        if self.normalizer is None:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        shot.restore(self.shots_dir)
        # t_disrupt = shot.t_disrupt
        # is_disruptive = shot.is_disruptive
        self.normalizer.apply(shot)

        pred, truth, is_disr = self.get_pred_truth_disr_by_shot(shot)
        use_signals = self.saved_conf['paths']['use_signals']
        np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                 shot=shot,
                 T_min_warn=self.T_min_warn,
                 T_max_warn=self.T_max_warn,
                 prediction=pred,
                 truth=truth,
                 use_signals=use_signals,
                 P_thresh=P_thresh_opt)
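
The archive written above mixes plain arrays with pickled Python objects (shot, use_signals), so those entries can only be read back with allow_pickle=True. A minimal reader sketch, assuming a placeholder shot number and the default empty extra_filename:

import numpy as np

# Hypothetical reader for the file produced by save_shot above; the shot
# number is a placeholder.
data = np.load('sig_161234.npz', allow_pickle=True)
prediction = data['prediction']      # plain array
truth = data['truth']
P_thresh = float(data['P_thresh'])   # scalar stored as a 0-d array
use_signals = data['use_signals']    # pickled objects; need allow_pickle
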
Example #2
    def plot_shot(self, shot, save_fig=True, normalize=True, truth=None,
                  prediction=None, P_thresh_opt=None, prediction_type='',
                  extra_filename=''):
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = self.conf['paths']['normalizer_path']
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if(shot.previously_saved(self.shots_dir)):
            shot.restore(self.shots_dir)
            t_disrupt = shot.t_disrupt
            is_disruptive =  shot.is_disruptive
            if normalize:
                self.normalizer.apply(shot)

            use_signals = self.saved_conf['paths']['use_signals']
            f,axarr = plt.subplots(len(use_signals)+1,1,sharex=True,figsize=(13,13))#, squeeze=False)
            plt.title(prediction_type)
            assert(np.all(shot.ttd.flatten() == truth.flatten()))
            for i,sig in enumerate(use_signals):
                num_channels = sig.num_channels
                ax = axarr[i]
                sig_arr = shot.signals_dict[sig]
                if num_channels == 1:
                    ax.plot(sig_arr[:,0],label = sig.description)
                else:
                    ax.imshow(sig_arr[:,:].T, aspect='auto', label = sig.description + " (profile)")
                    ax.set_ylim([0,num_channels])
                ax.legend(loc='best',fontsize=8)
                plt.setp(ax.get_xticklabels(),visible=False)
                plt.setp(ax.get_yticklabels(),fontsize=7)
                f.subplots_adjust(hspace=0)
                #print(sig)
                #print('min: {}, max: {}'.format(np.min(sig_arr), np.max(sig_arr)))
            ax = axarr[-1]
            if self.pred_ttd:
                ax.semilogy((-truth+0.0001),label='ground truth')
                ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
            else:
                ax.plot((truth+0.001),label='ground truth')
                ax.plot(prediction,'g',label='neural net prediction')
                ax.axhline(P_thresh_opt,color='k',label='trigger threshold')
            #ax.set_ylim([1e-5,1.1e0])
            ax.set_ylim([-2,2])
            if len(truth)-self.T_max_warn >= 0:
                ax.axvline(len(truth)-self.T_max_warn,color='r',label='max warning time')
            ax.axvline(len(truth)-self.T_min_warn,color='r',label='min warning time')
            ax.set_xlabel('T [ms]')
            #ax.legend(loc = 'lower left',fontsize=10)
            plt.setp(ax.get_yticklabels(),fontsize=7)
            # ax.grid()           
            if save_fig:
                plt.savefig('sig_fig_{}{}.png'.format(shot.number, extra_filename),
                            bbox_inches='tight')
                np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                         shot=shot, T_min_warn=self.T_min_warn,
                         T_max_warn=self.T_max_warn, prediction=prediction,
                         truth=truth, use_signals=use_signals,
                         P_thresh=P_thresh_opt)
            plt.close()
        else:
            print("Shot hasn't been processed")
Example #3
# set PRNG seed, unique for each worker, based on MPI task index for
# reproducible shuffling in guarantee_preprocessed() and training steps
np.random.seed(g.task_index)
random.seed(g.task_index)

only_predict = len(sys.argv) > 1
custom_path = None
if only_predict:
    custom_path = sys.argv[1]
    g.print_unique("predicting using path {}".format(custom_path))

#####################################################
#                 NORMALIZATION                     #
#####################################################
normalizer = Normalizer(conf)
if g.task_index == 0:
    # make sure preprocessing has been run, and results are saved to files
    # if not, only master MPI rank spawns thread pool to perform preprocessing
    (shot_list_train, shot_list_validate,
     shot_list_test) = guarantee_preprocessed(conf)
    # similarly, train normalizer (if necessary) w/ master MPI rank only
    normalizer.train()  # verbose=False only suppresses if purely loading
g.comm.Barrier()
g.print_unique("begin preprocessor+normalization (all MPI ranks)...")
# second call has ALL MPI ranks load preprocessed shots from .npz files
(shot_list_train, shot_list_validate,
 shot_list_test) = guarantee_preprocessed(conf, verbose=True)
# second call to normalizer training
normalizer.conf['data']['recompute_normalization'] = False
normalizer.train(verbose=True)
Example #4
    shot_list_tmp,t_range = create_shot_list_tmp(original_shot,time_points,sigs)
    y_prime,y_gold,disruptive = mpi_make_predictions(conf,shot_list_tmp,loader,custom_path)
    shot_list_tmp.make_light()
    return t_range,get_importance_measure_given_y_prime(y_prime,metric),y_prime[-1]

def difference_metric(y_prime,y_prime_orig):
    idx = np.argmax(y_prime_orig) 
    return (np.max(y_prime_orig) - y_prime[idx])/(np.max(y_prime_orig) - np.min(y_prime_orig))

def get_importance_measure_given_y_prime(y_prime,metric):
    differences = [metric(y_prime[i],y_prime[-1]) for i in range(len(y_prime))]
    return 1.0-np.array(differences)#/np.max(differences)
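
A small worked illustration (not from the original source) of difference_metric: it measures how far the perturbed prediction has dropped at the time where the original prediction peaked, normalized by the original range, so 0 means unchanged and 1 means it fell to the original minimum.

import numpy as np

# Assumes difference_metric as defined above is in scope.
y_prime_orig = np.array([0.1, 0.9, 0.3])
y_prime_perturbed = np.array([0.1, 0.5, 0.3])
score = difference_metric(y_prime_perturbed, y_prime_orig)
# score == (0.9 - 0.5) / (0.9 - 0.1) == 0.5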


print("normalization",end='')
normalizer = Normalizer(conf)
normalizer.train()
normalizer = ByShotAugmentator(normalizer)
loader = Loader(conf,normalizer)
print("...done")

# if not only_predict:
#     mpi_train(conf,shot_list_train,shot_list_validate,loader)

#load last model for testing
loader.set_inference_mode(True)
use_signals = copy.copy(conf['paths']['use_signals'])
use_signals.append(None)


Example #5
if only_predict:
    custom_path = sys.argv[2]
    print("predicting using path {}".format(custom_path))

#####################################################
####################Normalization####################
#####################################################
if task_index == 0:  #make sure preprocessing has been run, and is saved as a file
    shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(
        conf)
comm.Barrier()
shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(
    conf)

print("normalization", end='')
normalizer = Normalizer(conf)
normalizer.train()
loader = Loader(conf, normalizer)
print("...done")

#ensure training has a separate random seed for every worker
np.random.seed(task_index)
random.seed(task_index)
if not only_predict:
    mpi_train(conf, shot_list_train, shot_list_validate, loader)

#load last model for testing
loader.set_inference_mode(True)
print('saving results')
y_prime = []
y_gold = []
Example #6
if only_predict:
    custom_path = sys.argv[1]
    print("predicting using path {}".format(custom_path))

#####################################################
####################Normalization####################
#####################################################
if task_index == 0:  #make sure preprocessing has been run, and is saved as a file
    shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(
        conf)
comm.Barrier()
shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(
    conf)

print("normalization", end='')
raw_normalizer = Normalizer(conf)
raw_normalizer.train()
is_inference = False
normalizer = Augmentator(raw_normalizer, is_inference, conf)
loader = Loader(conf, normalizer)
print("...done")

if not only_predict:
    mpi_train(conf, shot_list_train, shot_list_validate, loader)

#load last model for testing
print('saving results')
y_prime = []
y_gold = []
disruptive = []
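
A plausible continuation (not shown in the snippet): the lists above are typically filled by running inference, for example by reusing the mpi_make_predictions call seen in Example #4; the exact arguments in the original file may differ.

# Hedged sketch only; the original script's prediction step is not shown.
loader.set_inference_mode(True)
y_prime, y_gold, disruptive = mpi_make_predictions(
    conf, shot_list_test, loader, custom_path)
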
Example #7
    def plot_shot(self,
                  shot,
                  save_fig=True,
                  normalize=True,
                  truth=None,
                  prediction=None,
                  P_thresh_opt=None,
                  prediction_type='',
                  extra_filename=''):
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if (shot.previously_saved(self.shots_dir)):
            shot.restore(self.shots_dir)
            if shot.signals_dict is not None:
                # make sure shot was saved with data
                # t_disrupt = shot.t_disrupt
                # is_disruptive = shot.is_disruptive
                if normalize:
                    self.normalizer.apply(shot)

                use_signals = self.saved_conf['paths']['use_signals']
                fontsize = 15
                lower_lim = 0  # len(pred)
                plt.close()
                # colors = ["b", "k"]
                # lss = ["-", "--"]
                f, axarr = plt.subplots(len(use_signals) + 1,
                                        1,
                                        sharex=True,
                                        figsize=(10, 15))
                plt.title(prediction_type)
                assert (np.all(shot.ttd.flatten() == truth.flatten()))
                xx = range(len(prediction))  # list(reversed(range(len(pred))))
                for i, sig in enumerate(use_signals):
                    ax = axarr[i]
                    num_channels = sig.num_channels
                    sig_arr = shot.signals_dict[sig]
                    if num_channels == 1:
                        ax.plot(xx, sig_arr[:, 0], linewidth=2)
                        ax.plot([], linestyle="none", label=sig.description)
                        if np.min(sig_arr[:, 0]) < 0:
                            ax.set_ylim([-6, 6])
                            ax.set_yticks([-5, 0, 5])
                        else:
                            ax.set_ylim([0, 8])
                            ax.set_yticks([0, 5])
                    else:
                        ax.imshow(sig_arr[:, :].T,
                                  aspect='auto',
                                  label=sig.description,
                                  cmap="inferno")
                        ax.set_ylim([0, num_channels])
                        ax.text(lower_lim + 200,
                                45,
                                sig.description,
                                bbox={
                                    'facecolor': 'white',
                                    'pad': 10
                                },
                                fontsize=fontsize - 5)
                        ax.set_yticks([0, num_channels / 2])
                        ax.set_yticklabels(["0", "0.5"])
                        ax.set_ylabel("$\\rho$", size=fontsize)
                    ax.legend(loc="best",
                              labelspacing=0.1,
                              fontsize=fontsize,
                              frameon=False)
                    ax.axvline(len(truth) - self.T_min_warn,
                               color='r',
                               linewidth=0.5)
                    plt.setp(ax.get_xticklabels(), visible=False)
                    plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                    f.subplots_adjust(hspace=0)
                ax = axarr[-1]
                # ax.semilogy((-truth+0.0001),label='ground truth')
                # ax.plot(-prediction+0.0001,'g',label='neural net prediction')
                # ax.axhline(-P_thresh_opt,color='k',label='trigger threshold')
                # nn = np.min(pred)
                ax.plot(xx, truth, 'g', label='target', linewidth=2)
                # ax.axhline(0.4,linestyle="--",color='k',label='threshold')
                ax.plot(xx, prediction, 'b', label='RNN output', linewidth=2)
                ax.axhline(P_thresh_opt,
                           linestyle="--",
                           color='k',
                           label='threshold')
                ax.set_ylim([-2, 2])
                ax.set_yticks([-1, 0, 1])
                # if len(truth)-T_max_warn >= 0:
                # ax.axvline(len(truth)-T_max_warn,color='r')#,label='max
                # warning time')
                # ,label='min warning time')
                ax.axvline(len(truth) - self.T_min_warn,
                           color='r',
                           linewidth=0.5)
                ax.set_xlabel('T [ms]', size=fontsize)
                # ax.axvline(2400)
                ax.legend(loc=(0.5, 0.7),
                          fontsize=fontsize - 5,
                          labelspacing=0.1,
                          frameon=False)
                plt.setp(ax.get_yticklabels(), fontsize=fontsize)
                plt.setp(ax.get_xticklabels(), fontsize=fontsize)
                # plt.xlim(0,200)
                plt.xlim([lower_lim, len(truth)])
                #         plt.savefig("{}.png".format(num),dpi=200,bbox_inches="tight")
                if save_fig:
                    plt.savefig('sig_fig_{}{}.png'.format(
                        shot.number, extra_filename),
                                bbox_inches='tight')
                    np.savez('sig_{}{}.npz'.format(shot.number,
                                                   extra_filename),
                             shot=shot,
                             T_min_warn=self.T_min_warn,
                             T_max_warn=self.T_max_warn,
                             prediction=prediction,
                             truth=truth,
                             use_signals=use_signals,
                             P_thresh=P_thresh_opt)
                # plt.show()
        else:
            print("Shot hasn't been processed")
Example #8
#     batch_size = conf['training']['batch_size_large']

np.random.seed(0)
random.seed(0)
#####################################################
####################PREPROCESSING####################
#####################################################
shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(
    conf)

#####################################################
####################Normalization####################
#####################################################

print("normalization", end='')
nn = Normalizer(conf)
nn.train()
loader = Loader(conf, nn)
print("...done")
print('Training on {} shots, testing on {} shots'.format(
    len(shot_list_train), len(shot_list_test)))

#####################################################
######################TRAINING#######################
#####################################################
#train(conf,shot_list_train,loader)
p = old_mp.Process(target=train,
                   args=(conf, shot_list_train, shot_list_validate, loader))
p.start()
p.join()
Example #9
    custom_path = sys.argv[1]
    print("predicting using path {}".format(custom_path))

#####################################################
#                   PREPROCESSING                   #
#####################################################
# TODO(KGF): check tuple unpack
(shot_list_train, shot_list_validate,
 shot_list_test) = guarantee_preprocessed(conf)

#####################################################
#                   NORMALIZATION                   #
#####################################################

print("normalization", end='')
nn = Normalizer(conf)
nn.train()
loader = Loader(conf, nn)
print("...done")
print('Training on {} shots, testing on {} shots'.format(
    len(shot_list_train), len(shot_list_test)))


#####################################################
#                    TRAINING                       #
#####################################################
# train(conf,shot_list_train,loader)
if not only_predict:
    p = old_mp.Process(target=train,
                       args=(conf, shot_list_train,
                             shot_list_validate, loader, shot_list_test))
Example #10
if conf['data']['normalizer'] == 'minmax':
    from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer
elif conf['data']['normalizer'] == 'meanvar':
    from plasma.preprocessor.normalize import MeanVarNormalizer as Normalizer
elif conf['data']['normalizer'] == 'var':
    from plasma.preprocessor.normalize import VarNormalizer as Normalizer #performs !much better than minmaxnormalizer
elif conf['data']['normalizer'] == 'averagevar':
    from plasma.preprocessor.normalize import AveragingVarNormalizer as Normalizer #performs !much better than minmaxnormalizer
else:
    print('unknown normalizer. exiting')
    exit(1)
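
For reference, the dispatch above inspects only conf['data']['normalizer']. A hypothetical minimal fragment of such a config (real configurations carry many more sections) would look like:

# Hypothetical fragment for illustration only.
conf = {'data': {'normalizer': 'var'}}  # or 'minmax', 'meanvar', 'averagevar'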

np.random.seed(1)

print("normalization",end='')
nn = Normalizer(conf)
nn.train()
loader = Loader(conf,nn)
shot_list_train,shot_list_validate,shot_list_test = loader.load_shotlists(conf)
print("...done")

print('Training on {} shots, testing on {} shots'.format(len(shot_list_train),len(shot_list_test)))
from plasma.models import runner

specific_runner = runner.HyperRunner(conf,loader,shot_list_train)

best_run, best_model = specific_runner.frnn_minimize(algo=tpe.suggest,max_evals=2,trials=Trials())
print (best_run)
print (best_model)
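
The tpe.suggest and Trials() names used above come from the hyperopt package; the snippet assumes an import along these lines appears earlier in the file:

from hyperopt import Trials, tpe
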
if only_predict:
    custom_path = sys.argv[1]
    print("predicting using path {}".format(custom_path))

#####################################################
####################Normalization####################
#####################################################
if task_index == 0: #make sure preprocessing has been run, and is saved as a file
    shot_list_train,shot_list_validate,shot_list_test = guarantee_preprocessed(conf)
comm.Barrier()
shot_list_train,shot_list_validate,shot_list_test = guarantee_preprocessed(conf)



print("normalization",end='')
raw_normalizer = Normalizer(conf)
raw_normalizer.train()
is_inference= False
normalizer = Augmentator(raw_normalizer,is_inference,conf)
loader = Loader(conf,normalizer)
print("...done")

if not only_predict:
    mpi_train(conf,shot_list_train,shot_list_validate,loader)

#load last model for testing
print('saving results')
y_prime = []
y_gold = []
disruptive= []