예제 #1
0
    def save_shot(self, shot, P_thresh_opt=0, extra_filename=''):
        """Restore and normalize ``shot``, then dump its predictions to disk.

        If no normalizer exists yet, one is built from ``self.saved_conf``,
        trained and switched to inference mode. Before that, the
        ``normalizer_path`` from ``self.conf`` (when ``self.conf`` is set) is
        copied into ``self.saved_conf``.

        Args:
            shot: shot object; restored from ``self.shots_dir`` and
                normalized in place before predictions are fetched.
            P_thresh_opt: threshold value stored under the ``P_thresh`` key.
            extra_filename: suffix appended to the output file name.

        Writes ``sig_<shot.number><extra_filename>.npz`` to the current
        working directory.
        """
        if self.normalizer is None:
            if self.conf is not None:
                conf_path = self.conf['paths']['normalizer_path']
                self.saved_conf['paths']['normalizer_path'] = conf_path
            fresh = Normalizer(self.saved_conf)
            fresh.train()
            self.normalizer = fresh
            self.normalizer.set_inference_mode(True)

        shot.restore(self.shots_dir)
        self.normalizer.apply(shot)

        # Only prediction and target are persisted; the third value returned
        # by get_pred_truth_disr_by_shot is not used here.
        prediction, target, _ = self.get_pred_truth_disr_by_shot(shot)
        signal_list = self.saved_conf['paths']['use_signals']
        out_path = 'sig_{}{}.npz'.format(shot.number, extra_filename)
        payload = dict(
            shot=shot,
            T_min_warn=self.T_min_warn,
            T_max_warn=self.T_max_warn,
            prediction=prediction,
            truth=target,
            use_signals=signal_list,
            P_thresh=P_thresh_opt,
        )
        np.savez(out_path, **payload)
예제 #2
0
    def plot_shot_old(self,
                      shot,
                      save_fig=True,
                      normalize=True,
                      truth=None,
                      prediction=None,
                      P_thresh_opt=None,
                      prediction_type='',
                      extra_filename=''):
        """Plot a shot's input signals stacked above the model output.

        One panel is drawn per signal in ``saved_conf['paths']['use_signals']``
        (a line for single-channel signals, an image for profiles), followed
        by a bottom panel comparing ``truth`` and ``prediction`` with the
        trigger threshold and the warning-time window.

        Args:
            shot: shot object; restored from ``self.shots_dir``. If it was
                never saved, a message is printed and nothing is plotted.
            save_fig: if True, write ``sig_fig_<number><extra>.png`` and
                ``sig_<number><extra>.npz`` to the current working directory.
            normalize: if True, apply ``self.normalizer`` (built and trained
                lazily when missing) to the shot.
            truth: target sequence. Required despite the ``None`` default; it
                must equal ``shot.ttd`` element-wise (asserted).
            prediction: model output sequence plotted against ``truth``.
            P_thresh_opt: trigger threshold drawn as a horizontal line.
            prediction_type: title placed on the top panel.
            extra_filename: suffix appended to the output file names.
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return

        shot.restore(self.shots_dir)
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf['paths']['use_signals']
        # squeeze=False always returns a 2-D array, so the panels can be
        # indexed uniformly even when only the output panel exists.
        f, axarr = plt.subplots(len(use_signals) + 1, 1, sharex=True,
                                figsize=(13, 13), squeeze=False)
        axes = axarr[:, 0]
        try:
            # Title the top panel explicitly: plt.title() targets the current
            # axes, which after plt.subplots is the bottom panel, and with
            # hspace=0 its title would be drawn over the panel above it.
            axes[0].set_title(prediction_type)
            # All files must agree on T_warning due to output of truth vs.
            # normalized shot ttd.
            assert np.all(shot.ttd.flatten() == truth.flatten())
            for ax, sig in zip(axes[:-1], use_signals):
                sig_arr = shot.signals_dict[sig]
                if sig.num_channels == 1:
                    ax.plot(sig_arr[:, 0], label=sig.description)
                else:
                    ax.imshow(sig_arr[:, :].T, aspect='auto',
                              label=sig.description + " (profile)")
                    ax.set_ylim([0, sig.num_channels])
                ax.legend(loc='best', fontsize=8)
                plt.setp(ax.get_xticklabels(), visible=False)
                plt.setp(ax.get_yticklabels(), fontsize=7)
            f.subplots_adjust(hspace=0)

            ax = axes[-1]
            if self.pred_ttd:
                ax.semilogy(-truth + 0.0001, label='ground truth')
                ax.plot(-prediction + 0.0001, 'g',
                        label='neural net prediction')
                ax.axhline(-P_thresh_opt, color='k', label='trigger threshold')
            else:
                ax.plot(truth + 0.001, label='ground truth')
                ax.plot(prediction, 'g', label='neural net prediction')
                ax.axhline(P_thresh_opt, color='k', label='trigger threshold')
            # NOTE(review): in the pred_ttd branch the axis is log-scaled and
            # matplotlib ignores a non-positive lower limit; confirm the
            # intended range for that case.
            ax.set_ylim([-2, 2])
            # A vertical line T steps before the end of the shot marks warning
            # time T, so T_max_warn gives the earliest line and T_min_warn
            # the latest.
            if len(truth) - self.T_max_warn >= 0:
                ax.axvline(len(truth) - self.T_max_warn, color='r',
                           label='max warning time')
            ax.axvline(len(truth) - self.T_min_warn, color='r',
                       label='min warning time')
            ax.set_xlabel('T [ms]')
            plt.setp(ax.get_yticklabels(), fontsize=7)
            if save_fig:
                f.savefig('sig_fig_{}{}.png'.format(shot.number,
                                                    extra_filename),
                          bbox_inches='tight')
                np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                         shot=shot,
                         T_min_warn=self.T_min_warn,
                         T_max_warn=self.T_max_warn,
                         prediction=prediction,
                         truth=truth,
                         use_signals=use_signals,
                         P_thresh=P_thresh_opt)
        finally:
            # Release the figure even when the consistency check fails.
            plt.close(f)
예제 #3
0
    def plot_shot(self,
                  shot,
                  save_fig=True,
                  normalize=True,
                  truth=None,
                  prediction=None,
                  P_thresh_opt=None,
                  prediction_type='',
                  extra_filename=''):
        """Plot a shot's input signals stacked above the model output.

        One panel is drawn per signal in ``saved_conf['paths']['use_signals']``
        (a line for single-channel signals, an image for profiles), followed
        by a bottom panel comparing ``truth`` and ``prediction`` with the
        trigger threshold and the warning-time window.

        Args:
            shot: shot object; restored from ``self.shots_dir``. If it was
                never saved, a message is printed and nothing is plotted.
            save_fig: if True, write ``sig_fig_<number><extra>.png`` and
                ``sig_<number><extra>.npz`` to the current working directory.
            normalize: if True, apply ``self.normalizer`` (built and trained
                lazily when missing) to the shot.
            truth: target sequence. Required despite the ``None`` default; it
                must equal ``shot.ttd`` element-wise (asserted).
            prediction: model output sequence plotted against ``truth``.
            P_thresh_opt: trigger threshold drawn as a horizontal line.
            prediction_type: title placed on the top panel.
            extra_filename: suffix appended to the output file names.
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return

        shot.restore(self.shots_dir)
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf['paths']['use_signals']
        # squeeze=False always returns a 2-D array, so the panels can be
        # indexed uniformly even when only the output panel exists.
        f, axarr = plt.subplots(len(use_signals) + 1, 1, sharex=True,
                                figsize=(13, 13), squeeze=False)
        axes = axarr[:, 0]
        try:
            # Title the top panel explicitly: plt.title() targets the current
            # axes, which after plt.subplots is the bottom panel, and with
            # hspace=0 its title would be drawn over the panel above it.
            axes[0].set_title(prediction_type)
            # truth must describe the same time-to-disruption as the shot.
            assert np.all(shot.ttd.flatten() == truth.flatten())
            for ax, sig in zip(axes[:-1], use_signals):
                sig_arr = shot.signals_dict[sig]
                if sig.num_channels == 1:
                    ax.plot(sig_arr[:, 0], label=sig.description)
                else:
                    ax.imshow(sig_arr[:, :].T, aspect='auto',
                              label=sig.description + " (profile)")
                    ax.set_ylim([0, sig.num_channels])
                ax.legend(loc='best', fontsize=8)
                plt.setp(ax.get_xticklabels(), visible=False)
                plt.setp(ax.get_yticklabels(), fontsize=7)
            f.subplots_adjust(hspace=0)

            ax = axes[-1]
            if self.pred_ttd:
                ax.semilogy(-truth + 0.0001, label='ground truth')
                ax.plot(-prediction + 0.0001, 'g',
                        label='neural net prediction')
                ax.axhline(-P_thresh_opt, color='k', label='trigger threshold')
            else:
                ax.plot(truth + 0.001, label='ground truth')
                ax.plot(prediction, 'g', label='neural net prediction')
                ax.axhline(P_thresh_opt, color='k', label='trigger threshold')
            # NOTE(review): in the pred_ttd branch the axis is log-scaled and
            # matplotlib ignores a non-positive lower limit; confirm the
            # intended range for that case.
            ax.set_ylim([-2, 2])
            # A vertical line T steps before the end of the shot marks warning
            # time T, so T_max_warn gives the earliest line and T_min_warn
            # the latest.
            if len(truth) - self.T_max_warn >= 0:
                ax.axvline(len(truth) - self.T_max_warn, color='r',
                           label='max warning time')
            ax.axvline(len(truth) - self.T_min_warn, color='r',
                       label='min warning time')
            ax.set_xlabel('T [ms]')
            plt.setp(ax.get_yticklabels(), fontsize=7)
            if save_fig:
                f.savefig('sig_fig_{}{}.png'.format(shot.number,
                                                    extra_filename),
                          bbox_inches='tight')
                np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                         shot=shot,
                         T_min_warn=self.T_min_warn,
                         T_max_warn=self.T_max_warn,
                         prediction=prediction,
                         truth=truth,
                         use_signals=use_signals,
                         P_thresh=P_thresh_opt)
        finally:
            # Release the figure even when the consistency check fails.
            plt.close(f)
예제 #4
0
    def save_shot(self, shot, P_thresh_opt=0, extra_filename=''):
        """Normalize a stored shot and save its predictions to an ``.npz`` file.

        If no normalizer exists yet, one is built from ``self.saved_conf``,
        trained and switched to inference mode. Before that, the
        ``normalizer_path`` from ``self.conf`` (when ``self.conf`` is set) is
        copied into ``self.saved_conf``.

        Args:
            shot: shot object; restored from ``self.shots_dir`` and
                normalized in place before predictions are fetched.
            P_thresh_opt: threshold value stored under the ``P_thresh`` key.
            extra_filename: suffix appended to the output file name.

        Writes ``sig_<shot.number><extra_filename>.npz`` to the current
        working directory.
        """
        if self.normalizer is None:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        shot.restore(self.shots_dir)
        self.normalizer.apply(shot)

        # The third return value (presumably the disruptivity flag) is not
        # persisted, so it is discarded.
        pred, truth, _ = self.get_pred_truth_disr_by_shot(shot)
        use_signals = self.saved_conf['paths']['use_signals']
        np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                 shot=shot,
                 T_min_warn=self.T_min_warn,
                 T_max_warn=self.T_max_warn,
                 prediction=pred,
                 truth=truth,
                 use_signals=use_signals,
                 P_thresh=P_thresh_opt)
예제 #5
0
    def plot_shot(self,
                  shot,
                  save_fig=True,
                  normalize=True,
                  truth=None,
                  prediction=None,
                  P_thresh_opt=None,
                  prediction_type='',
                  extra_filename=''):
        """Plot a shot's input signals stacked above the RNN output.

        One panel is drawn per signal in ``saved_conf['paths']['use_signals']``
        (a line for single-channel signals, an ``inferno`` image for
        profiles), followed by a bottom panel with ``truth`` ("target"),
        ``prediction`` ("RNN output") and the threshold. Every panel carries
        a red line ``T_min_warn`` steps before the end of the shot.

        Args:
            shot: shot object; restored from ``self.shots_dir``. A message is
                printed and nothing is plotted if it was never saved or was
                saved without signal data.
            save_fig: if True, write ``sig_fig_<number><extra>.png`` and
                ``sig_<number><extra>.npz`` to the current working directory.
            normalize: if True, apply ``self.normalizer`` (built and trained
                lazily when missing) to the shot.
            truth: target sequence. Required despite the ``None`` default; it
                must equal ``shot.ttd`` element-wise (asserted).
            prediction: model output sequence; its length sets the x axis.
            P_thresh_opt: threshold drawn as a dashed horizontal line.
            prediction_type: title placed on the top panel.
            extra_filename: suffix appended to the output file names.

        The figure is left open after the call (the current figure is closed
        at the start of the next one).
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = (
                    self.conf['paths']['normalizer_path'])
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return

        shot.restore(self.shots_dir)
        if shot.signals_dict is None:
            # The shot was saved without data, so there is nothing to draw.
            print("Shot was saved without signal data")
            return
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf['paths']['use_signals']
        fontsize = 15
        lower_lim = 0
        # Close the current figure (e.g. one left open by a previous call).
        plt.close()
        # squeeze=False always returns a 2-D array, so the panels can be
        # indexed uniformly even when only the output panel exists.
        f, axarr = plt.subplots(len(use_signals) + 1,
                                1,
                                sharex=True,
                                figsize=(10, 15),
                                squeeze=False)
        axes = axarr[:, 0]
        # Title the top panel explicitly: plt.title() targets the current
        # axes, which after plt.subplots is the bottom panel, and with
        # hspace=0 its title would be drawn over the panel above it.
        axes[0].set_title(prediction_type)
        assert np.all(shot.ttd.flatten() == truth.flatten())
        xx = range(len(prediction))
        warn_x = len(truth) - self.T_min_warn
        for ax, sig in zip(axes[:-1], use_signals):
            num_channels = sig.num_channels
            sig_arr = shot.signals_dict[sig]
            if num_channels == 1:
                ax.plot(xx, sig_arr[:, 0], linewidth=2)
                # Empty, invisible line whose only job is to put the signal
                # name into the legend exactly once.
                ax.plot([], linestyle="none", label=sig.description)
                if np.min(sig_arr[:, 0]) < 0:
                    ax.set_ylim([-6, 6])
                    ax.set_yticks([-5, 0, 5])
                else:
                    ax.set_ylim([0, 8])
                    ax.set_yticks([0, 5])
            else:
                ax.imshow(sig_arr[:, :].T,
                          aspect='auto',
                          label=sig.description,
                          cmap="inferno")
                ax.set_ylim([0, num_channels])
                # Profiles are labelled with an in-plot text box instead.
                # NOTE(review): y=45 is in channel units and assumes profiles
                # have more than 45 channels; confirm.
                ax.text(lower_lim + 200,
                        45,
                        sig.description,
                        bbox={
                            'facecolor': 'white',
                            'pad': 10
                        },
                        fontsize=fontsize - 5)
                ax.set_yticks([0, num_channels / 2])
                ax.set_yticklabels(["0", "0.5"])
                ax.set_ylabel("$\\rho$", size=fontsize)
            ax.legend(loc="best",
                      labelspacing=0.1,
                      fontsize=fontsize,
                      frameon=False)
            ax.axvline(warn_x, color='r', linewidth=0.5)
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), fontsize=fontsize)
        f.subplots_adjust(hspace=0)

        ax = axes[-1]
        ax.plot(xx, truth, 'g', label='target', linewidth=2)
        ax.plot(xx, prediction, 'b', label='RNN output', linewidth=2)
        ax.axhline(P_thresh_opt,
                   linestyle="--",
                   color='k',
                   label='threshold')
        ax.set_ylim([-2, 2])
        ax.set_yticks([-1, 0, 1])
        ax.axvline(warn_x, color='r', linewidth=0.5)
        ax.set_xlabel('T [ms]', size=fontsize)
        ax.legend(loc=(0.5, 0.7),
                  fontsize=fontsize - 5,
                  labelspacing=0.1,
                  frameon=False)
        plt.setp(ax.get_yticklabels(), fontsize=fontsize)
        plt.setp(ax.get_xticklabels(), fontsize=fontsize)
        # sharex=True propagates this range to every panel.
        ax.set_xlim([lower_lim, len(truth)])
        if save_fig:
            f.savefig('sig_fig_{}{}.png'.format(shot.number, extra_filename),
                      bbox_inches='tight')
            np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                     shot=shot,
                     T_min_warn=self.T_min_warn,
                     T_max_warn=self.T_max_warn,
                     prediction=prediction,
                     truth=truth,
                     use_signals=use_signals,
                     P_thresh=P_thresh_opt)
예제 #6
0
 def plot_shot(self, shot, save_fig=True, normalize=True, truth=None,
               prediction=None, P_thresh_opt=None, prediction_type='',
               extra_filename=''):
     """Plot a shot's input signals stacked above the RNN output.

     One panel is drawn per signal in ``saved_conf['paths']['use_signals']``
     (a line for single-channel signals, an ``inferno`` image for profiles),
     followed by a bottom panel with ``truth`` ("target"), ``prediction``
     ("RNN output") and the threshold. Every panel carries a red line
     ``T_min_warn`` steps before the end of the shot.

     Args:
         shot: shot object; restored from ``self.shots_dir``. A message is
             printed and nothing is plotted if it was never saved or was
             saved without signal data.
         save_fig: if True, write ``sig_fig_<number><extra>.png`` and
             ``sig_<number><extra>.npz`` to the current working directory.
         normalize: if True, apply ``self.normalizer`` (built and trained
             lazily when missing) to the shot.
         truth: target sequence. Required despite the ``None`` default; it
             must equal ``shot.ttd`` element-wise (asserted).
         prediction: model output sequence; its length sets the x axis.
         P_thresh_opt: threshold drawn as a dashed horizontal line.
         prediction_type: title placed on the top panel.
         extra_filename: suffix appended to the output file names.

     The figure is left open after the call (the current figure is closed
     at the start of the next one).
     """
     if self.normalizer is None and normalize:
         if self.conf is not None:
             self.saved_conf['paths']['normalizer_path'] = (
                 self.conf['paths']['normalizer_path'])
         nn = Normalizer(self.saved_conf)
         nn.train()
         self.normalizer = nn
         self.normalizer.set_inference_mode(True)

     if not shot.previously_saved(self.shots_dir):
         print("Shot hasn't been processed")
         return

     shot.restore(self.shots_dir)
     if shot.signals_dict is None:
         # The shot was saved without data, so there is nothing to draw.
         print("Shot was saved without signal data")
         return
     if normalize:
         self.normalizer.apply(shot)

     use_signals = self.saved_conf['paths']['use_signals']
     fontsize = 15
     lower_lim = 0
     # Close the current figure (e.g. one left open by a previous call).
     plt.close()
     # squeeze=False always returns a 2-D array, so the panels can be
     # indexed uniformly even when only the output panel exists.
     f, axarr = plt.subplots(len(use_signals) + 1, 1, sharex=True,
                             figsize=(10, 15), squeeze=False)
     axes = axarr[:, 0]
     # Title the top panel explicitly: plt.title() targets the current axes,
     # which after plt.subplots is the bottom panel, and with hspace=0 its
     # title would be drawn over the panel above it.
     axes[0].set_title(prediction_type)
     assert np.all(shot.ttd.flatten() == truth.flatten())
     xx = range(len(prediction))
     warn_x = len(truth) - self.T_min_warn
     for ax, sig in zip(axes[:-1], use_signals):
         num_channels = sig.num_channels
         sig_arr = shot.signals_dict[sig]
         if num_channels == 1:
             ax.plot(xx, sig_arr[:, 0], linewidth=2)
             # Empty, invisible line whose only job is to put the signal name
             # into the legend exactly once.
             ax.plot([], linestyle="none", label=sig.description)
             if np.min(sig_arr[:, 0]) < 0:
                 ax.set_ylim([-6, 6])
                 ax.set_yticks([-5, 0, 5])
             else:
                 ax.set_ylim([0, 8])
                 ax.set_yticks([0, 5])
         else:
             ax.imshow(sig_arr[:, :].T, aspect='auto', label=sig.description,
                       cmap="inferno")
             ax.set_ylim([0, num_channels])
             # Profiles are labelled with an in-plot text box instead.
             # NOTE(review): y=45 is in channel units and assumes profiles
             # have more than 45 channels; confirm.
             ax.text(lower_lim + 200, 45, sig.description,
                     bbox={'facecolor': 'white', 'pad': 10},
                     fontsize=fontsize - 5)
             ax.set_yticks([0, num_channels / 2])
             ax.set_yticklabels(["0", "0.5"])
             ax.set_ylabel("$\\rho$", size=fontsize)
         ax.legend(loc="best", labelspacing=0.1, fontsize=fontsize,
                   frameon=False)
         ax.axvline(warn_x, color='r', linewidth=0.5)
         plt.setp(ax.get_xticklabels(), visible=False)
         plt.setp(ax.get_yticklabels(), fontsize=fontsize)
     f.subplots_adjust(hspace=0)

     ax = axes[-1]
     ax.plot(xx, truth, 'g', label='target', linewidth=2)
     ax.plot(xx, prediction, 'b', label='RNN output', linewidth=2)
     ax.axhline(P_thresh_opt, linestyle="--", color='k', label='threshold')
     ax.set_ylim([-2, 2])
     ax.set_yticks([-1, 0, 1])
     ax.axvline(warn_x, color='r', linewidth=0.5)
     ax.set_xlabel('T [ms]', size=fontsize)
     ax.legend(loc=(0.5, 0.7), fontsize=fontsize - 5, labelspacing=0.1,
               frameon=False)
     plt.setp(ax.get_yticklabels(), fontsize=fontsize)
     plt.setp(ax.get_xticklabels(), fontsize=fontsize)
     # sharex=True propagates this range to every panel.
     ax.set_xlim([lower_lim, len(truth)])
     if save_fig:
         f.savefig('sig_fig_{}{}.png'.format(shot.number, extra_filename),
                   bbox_inches='tight')
         np.savez('sig_{}{}.npz'.format(shot.number, extra_filename),
                  shot=shot,
                  T_min_warn=self.T_min_warn,
                  T_max_warn=self.T_max_warn,
                  prediction=prediction,
                  truth=truth,
                  use_signals=use_signals,
                  P_thresh=P_thresh_opt)