Exemplo n.º 1
0
    def save_shot(self, shot, P_thresh_opt=0, extra_filename=""):
        """Normalize a restored shot and archive it together with its predictions.

        The shot is restored from ``self.shots_dir``, passed through
        ``self.normalizer.apply`` and then written, along with the model
        output returned by ``self.get_pred_truth_disr_by_shot``, to
        ``sig_<shot.number><extra_filename>.npz`` in the working directory.

        Args:
            shot: Shot object to restore, normalize and store.
            P_thresh_opt: Alarm threshold, stored under the ``P_thresh`` key.
            extra_filename: Suffix inserted between the shot number and the
                ``.npz`` extension.
        """
        if self.normalizer is None:
            # Build the normalizer lazily. When a current configuration is
            # present, its normalizer path overrides the saved one.
            if self.conf is not None:
                current_path = self.conf["paths"]["normalizer_path"]
                self.saved_conf["paths"]["normalizer_path"] = current_path
            fresh = Normalizer(self.saved_conf)
            fresh.train()
            self.normalizer = fresh
            self.normalizer.set_inference_mode(True)

        shot.restore(self.shots_dir)
        self.normalizer.apply(shot)

        prediction, truth, _is_disruptive = self.get_pred_truth_disr_by_shot(shot)
        # Keyword order matches the archive layout written previously.
        archive_contents = {
            "shot": shot,
            "T_min_warn": self.T_min_warn,
            "T_max_warn": self.T_max_warn,
            "prediction": prediction,
            "truth": truth,
            "use_signals": self.saved_conf["paths"]["use_signals"],
            "P_thresh": P_thresh_opt,
        }
        archive_name = "sig_{}{}.npz".format(shot.number, extra_filename)
        np.savez(archive_name, **archive_contents)
Exemplo n.º 2
0
    def plot_shot_old(self, shot, save_fig=True, normalize=True, truth=None,
                      prediction=None, P_thresh_opt=None, prediction_type='',
                      extra_filename=''):
        """Plot every input signal of ``shot`` above the network output (legacy layout).

        One panel is drawn per entry of ``saved_conf['paths']['use_signals']``.
        Single-channel signals are drawn as lines and multi-channel signals as
        images. A bottom panel shows ``truth``, ``prediction``, the trigger
        threshold and the warning-window markers.

        Args:
            shot: Shot object; restored from ``self.shots_dir`` and, when
                ``normalize`` is true, passed through ``self.normalizer.apply``.
            save_fig: If true, write ``sig_fig_<number><extra_filename>.png``.
            normalize: Apply ``self.normalizer``, building it lazily if needed.
            truth: Target sequence; asserted equal to ``shot.ttd`` element-wise.
            prediction: Network output aligned with ``truth``.
            P_thresh_opt: Alarm threshold, drawn as a horizontal line.
            prediction_type: Unused; kept for interface compatibility.
            extra_filename: Suffix for the output file name.
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = self.conf['paths']['normalizer_path']
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return

        shot.restore(self.shots_dir)
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf['paths']['use_signals']
        # squeeze=False always returns a 2-D array of axes, so the indexing
        # below also works when only the prediction panel exists.
        f, axarr = plt.subplots(len(use_signals) + 1, 1, sharex=True,
                                figsize=(13, 13), squeeze=False)
        axarr = axarr[:, 0]
        # All files must agree on T_warning because truth is compared against
        # the normalized shot's ttd.
        assert np.all(shot.ttd.flatten() == truth.flatten())
        for ax, sig in zip(axarr, use_signals):
            num_channels = sig.num_channels
            sig_arr = shot.signals_dict[sig]
            if num_channels == 1:
                ax.plot(sig_arr[:, 0], label=sig.description)
            else:
                ax.imshow(sig_arr[:, :].T, aspect='auto',
                          label=sig.description + " (profile)")
                ax.set_ylim([0, num_channels])
            ax.legend(loc='best', fontsize=8)
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), fontsize=7)
        f.subplots_adjust(hspace=0)

        # Bottom panel: target versus network output.
        ax = axarr[-1]
        if self.pred_ttd:
            # NOTE(review): set_ylim([-2, 2]) below is not meaningful on the
            # log axis created by semilogy; confirm the intended range.
            ax.semilogy((-truth + 0.0001), label='ground truth')
            ax.plot(-prediction + 0.0001, 'g', label='neural net prediction')
            ax.axhline(-P_thresh_opt, color='k', label='trigger threshold')
        else:
            ax.plot((truth + 0.001), label='ground truth')
            ax.plot(prediction, 'g', label='neural net prediction')
            ax.axhline(P_thresh_opt, color='k', label='trigger threshold')
        ax.set_ylim([-2, 2])
        # The earliest allowed alarm sits T_max_warn samples before the end,
        # and the latest useful alarm sits T_min_warn samples before the end.
        if len(truth) - self.T_max_warn >= 0:
            ax.axvline(len(truth) - self.T_max_warn, color='r',
                       label='max warning time')
        ax.axvline(len(truth) - self.T_min_warn, color='r',
                   label='min warning time')
        ax.set_xlabel('T [ms]')
        plt.setp(ax.get_yticklabels(), fontsize=7)
        if save_fig:
            plt.savefig('sig_fig_{}{}.png'.format(shot.number, extra_filename),
                        bbox_inches='tight')
        plt.close()
Exemplo n.º 3
0
    def plot_shot(
        self,
        shot,
        save_fig=True,
        normalize=True,
        truth=None,
        prediction=None,
        P_thresh_opt=None,
        prediction_type="",
        extra_filename="",
    ):
        """Plot a shot's signals with the disruption score and the alarm time.

        Layout is a fixed 5-panel figure:
          - Single-channel signals share the top panels, 7 per panel.
          - Profile signals fill panels upward from ``axarr[-2]``.
          - The bottom panel holds the network output and the threshold.

        The first sample whose score exceeds ``P_thresh_opt`` (the alarm) is
        marked with a dashed grey line on every panel and a red star on the
        output panel. Nothing is marked when the alarm never fires.

        Args:
            shot: Shot object; restored from ``self.shots_dir`` and, when
                ``normalize`` is true, passed through ``self.normalizer.apply``.
            save_fig: If true, write ``sig_fig_<number><extra_filename>.png``
                and ``sig_<number><extra_filename>.npz``.
            normalize: Apply ``self.normalizer``, building it lazily if needed.
            truth: Target sequence; asserted equal to ``shot.ttd`` element-wise.
            prediction: Disruption score per time step.
            P_thresh_opt: Alarm threshold.
            prediction_type: Used as the figure title.
            extra_filename: Suffix for the output file names.
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf["paths"]["normalizer_path"] = self.conf[
                    "paths"]["normalizer_path"]
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return
        shot.restore(self.shots_dir)
        if shot.signals_dict is None:
            # The shot was saved without data, so there is nothing to plot.
            return
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf["paths"]["use_signals"]
        fontsize = 18
        lower_lim = 0
        plt.close()
        colors = ["b", "green", "red", "c", "m", "orange", "k", "y"]
        f, axarr = plt.subplots(4 + 1, 1, sharex=True, figsize=(18, 15))
        plt.title(prediction_type)
        assert np.all(shot.ttd.flatten() == truth.flatten())
        xx = range(len(prediction))
        t_min_warn_line = len(truth) - self.T_min_warn

        # Index of the first score above threshold, or None if the alarm never
        # fires. A sentinel is needed because the alarm can fire at index 0.
        p0 = next(
            (k for k, p in enumerate(prediction) if p > P_thresh_opt), None)

        n_profiles = 0
        for i, sig in enumerate(use_signals):
            num_channels = sig.num_channels
            sig_arr = shot.signals_dict[sig]
            if num_channels == 1:
                ax = axarr[i // 7]
                ax.plot(
                    xx,
                    sig_arr[:, 0],
                    linewidth=2,
                    color=colors[i % 7],
                    label=sig.description,
                )
                if np.min(sig_arr[:, 0]) < -100000:
                    ax.set_ylim([-6, 6])
                    ax.set_yticks([-5, 0, 5])
                else:
                    ax.set_ylim([-2, 11])
                    ax.set_yticks([0, 5, 10])
            else:
                # Profiles fill panels upward from the one above the output.
                ax = axarr[-2 - n_profiles]
                n_profiles += 1
                ax.imshow(
                    sig_arr[:, :].T,
                    aspect="auto",
                    label=sig.description,
                    cmap="inferno",
                )
                ax.set_ylim([0, num_channels])
                ax.text(
                    lower_lim + 200,
                    45,
                    sig.description,
                    bbox={
                        "facecolor": "white",
                        "pad": 10
                    },
                    fontsize=fontsize,
                )
                ax.set_yticks([0, num_channels / 2])
                ax.set_yticklabels(["0", "0.5"])
                ax.set_ylabel("$\\rho$", size=fontsize)
            ax.legend(
                loc="center left",
                labelspacing=0.1,
                bbox_to_anchor=(1, 0.5),
                fontsize=fontsize,
                frameon=False,
            )
            ax.axvline(t_min_warn_line, color="r")
            if p0 is not None:
                ax.axvline(p0, linestyle="--", color="darkgrey")
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), fontsize=fontsize)
            f.subplots_adjust(hspace=0)

        # Bottom panel: disruption score versus threshold.
        ax = axarr[-1]
        ax.plot(
            xx,
            prediction,
            "b",
            label="RNN output--Disruption score",
            linewidth=2,
            zorder=1,
        )
        ax.axhline(P_thresh_opt,
                   linestyle="--",
                   color="k",
                   label="threshold",
                   zorder=2)
        ax.set_ylim([min(prediction), max(prediction)])
        ax.set_yticks([0, 1])
        if p0 is not None:
            ax.axvline(p0, linestyle="--", color="darkgrey")
            # The star marks the alarm: the first score above threshold.
            ax.scatter(xx[p0],
                       prediction[p0],
                       s=300,
                       marker="*",
                       color="r",
                       zorder=3)

        ax.axvline(t_min_warn_line, color="r", linewidth=0.5)
        ax.set_xlabel("T [ms]", size=fontsize)
        ax.legend(
            loc="center left",
            labelspacing=0.1,
            bbox_to_anchor=(1, 0.5),
            fontsize=fontsize + 2,
            frameon=False,
        )
        plt.setp(ax.get_yticklabels(), fontsize=fontsize)
        plt.setp(ax.get_xticklabels(), fontsize=fontsize)
        plt.xlim([lower_lim, len(truth)])
        if save_fig:
            plt.savefig(
                "sig_fig_{}{}.png".format(shot.number, extra_filename),
                bbox_inches="tight",
            )
            np.savez(
                "sig_{}{}.npz".format(shot.number, extra_filename),
                shot=shot,
                T_min_warn=self.T_min_warn,
                T_max_warn=self.T_max_warn,
                prediction=prediction,
                truth=truth,
                use_signals=use_signals,
                P_thresh=P_thresh_opt,
            )
    def plot_shot(
        self,
        shot,
        save_fig=True,
        normalize=True,
        truth=None,
        prediction=None,
        P_thresh_opt=None,
        prediction_type="",
        extra_filename="",
    ):
        """Plot each input signal in its own panel above target and RNN output.

        Single-channel signals are drawn as lines, and the signal description
        appears once in the legend via an invisible artist. Multi-channel
        signals are drawn as images with an in-panel text label. The bottom
        panel shows ``truth``, ``prediction`` and the threshold.

        Args:
            shot: Shot object; restored from ``self.shots_dir`` and, when
                ``normalize`` is true, passed through ``self.normalizer.apply``.
            save_fig: If true, write ``sig_fig_<number><extra_filename>.png``
                and ``sig_<number><extra_filename>.npz``.
            normalize: Apply ``self.normalizer``, building it lazily if needed.
            truth: Target sequence; asserted equal to ``shot.ttd`` element-wise.
            prediction: RNN output aligned with ``truth``.
            P_thresh_opt: Alarm threshold, drawn as a dashed line.
            prediction_type: Used as the figure title.
            extra_filename: Suffix for the output file names.
        """
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf["paths"]["normalizer_path"] = self.conf[
                    "paths"]["normalizer_path"]
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return
        shot.restore(self.shots_dir)
        if shot.signals_dict is None:
            # The shot was saved without data, so there is nothing to plot.
            return
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf["paths"]["use_signals"]
        fontsize = 15
        lower_lim = 0
        plt.close()
        # squeeze=False always returns a 2-D array of axes, so the indexing
        # below also works when only the output panel exists.
        f, axarr = plt.subplots(len(use_signals) + 1,
                                1,
                                sharex=True,
                                figsize=(10, 15),
                                squeeze=False)
        axarr = axarr[:, 0]
        plt.title(prediction_type)
        assert np.all(shot.ttd.flatten() == truth.flatten())
        xx = range(len(prediction))
        t_min_warn_line = len(truth) - self.T_min_warn
        for ax, sig in zip(axarr, use_signals):
            num_channels = sig.num_channels
            sig_arr = shot.signals_dict[sig]
            if num_channels == 1:
                ax.plot(xx, sig_arr[:, 0], linewidth=2)
                # Invisible artist so the legend shows the description as text
                # (added exactly once per panel).
                ax.plot([], linestyle="none", label=sig.description)
                if np.min(sig_arr[:, 0]) < 0:
                    ax.set_ylim([-6, 6])
                    ax.set_yticks([-5, 0, 5])
                else:
                    ax.set_ylim([0, 8])
                    ax.set_yticks([0, 5])
            else:
                ax.imshow(
                    sig_arr[:, :].T,
                    aspect="auto",
                    label=sig.description,
                    cmap="inferno",
                )
                ax.set_ylim([0, num_channels])
                ax.text(
                    lower_lim + 200,
                    45,
                    sig.description,
                    bbox={
                        "facecolor": "white",
                        "pad": 10
                    },
                    fontsize=fontsize - 5,
                )
                ax.set_yticks([0, num_channels / 2])
                ax.set_yticklabels(["0", "0.5"])
                ax.set_ylabel("$\\rho$", size=fontsize)
            ax.legend(loc="best",
                      labelspacing=0.1,
                      fontsize=fontsize,
                      frameon=False)
            ax.axvline(t_min_warn_line, color="r", linewidth=0.5)
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), fontsize=fontsize)
            f.subplots_adjust(hspace=0)

        # Bottom panel: target, RNN output and threshold.
        ax = axarr[-1]
        ax.plot(xx, truth, "g", label="target", linewidth=2)
        ax.plot(xx, prediction, "b", label="RNN output", linewidth=2)
        ax.axhline(P_thresh_opt,
                   linestyle="--",
                   color="k",
                   label="threshold")
        ax.set_ylim([-2, 2])
        ax.set_yticks([-1, 0, 1])
        ax.axvline(t_min_warn_line, color="r", linewidth=0.5)
        ax.set_xlabel("T [ms]", size=fontsize)
        ax.legend(
            loc=(0.5, 0.7),
            fontsize=fontsize - 5,
            labelspacing=0.1,
            frameon=False,
        )
        plt.setp(ax.get_yticklabels(), fontsize=fontsize)
        plt.setp(ax.get_xticklabels(), fontsize=fontsize)
        plt.xlim([lower_lim, len(truth)])
        if save_fig:
            plt.savefig("sig_fig_{}{}.png".format(shot.number, extra_filename))
            np.savez(
                "sig_{}{}.npz".format(shot.number, extra_filename),
                shot=shot,
                T_min_warn=self.T_min_warn,
                T_max_warn=self.T_max_warn,
                prediction=prediction,
                truth=truth,
                use_signals=use_signals,
                P_thresh=P_thresh_opt,
            )
Exemplo n.º 5
0
    def plot_shot(self, shot, save_fig=True, normalize=True, truth=None,
                  prediction=None, P_thresh_opt=None, prediction_type='',
                  extra_filename=''):
        """Plot a shot's signals above two-column predictions and reference traces.

        Layout is a fixed 5-panel figure:
          - Single-channel signals share the top panels, 7 per panel.
          - Profile signals fill panels upward from ``axarr[-2]``.
          - The bottom panel holds the predictions.

        The bottom panel shows ``prediction[:, 0]`` and ``prediction[:, 1]``
        with negative values clipped to 0. It also shows rescaled copies of
        ``truth[:, 1]`` and of the locked-mode and n=1 finite-frequency signals
        (when those are configured in ``all_signals``). Each reference trace is
        rescaled so its peak matches the largest prediction.

        Args:
            shot: Shot object; restored from ``self.shots_dir`` and, when
                ``normalize`` is true, passed through ``self.normalizer.apply``.
            save_fig: If true, write ``sig_fig_<number><extra_filename>.png``.
            normalize: Apply ``self.normalizer``, building it lazily if needed.
            truth: 2-D target array; column 1 is plotted (indexing ``[:, 1]``).
            prediction: 2-D array with at least two columns. It is not
                modified; clipping is applied to a copy.
            P_thresh_opt: Unused here; kept for interface compatibility.
            prediction_type: Only printed.
            extra_filename: Suffix for the output file name.
        """
        print('plotting shot,', shot, prediction_type, prediction.shape)
        if self.normalizer is None and normalize:
            if self.conf is not None:
                self.saved_conf['paths']['normalizer_path'] = self.conf['paths']['normalizer_path']
            nn = Normalizer(self.saved_conf)
            nn.train()
            self.normalizer = nn
            self.normalizer.set_inference_mode(True)

        if not shot.previously_saved(self.shots_dir):
            print("Shot hasn't been processed")
            return
        shot.restore(self.shots_dir)
        if shot.signals_dict is None:
            # The shot was saved without data, so there is nothing to plot.
            return
        if normalize:
            self.normalizer.apply(shot)

        use_signals = self.saved_conf['paths']['use_signals']
        all_signals = self.saved_conf['paths']['all_signals']
        fontsize = 18
        lower_lim = 0
        plt.close()
        colors = ['b', 'green', 'red', 'c', 'm', 'orange', 'k', 'y']
        f, axarr = plt.subplots(4 + 1, 1, sharex=True, figsize=(18, 18))
        xx = range(prediction.shape[0])

        # Reference traces for the bottom panel. Each stays None when the
        # corresponding signal is not listed in all_signals.
        target_plot = None  # 'n1 finite frequency signals', first channel
        lm_plot = None  # 'Locked mode amplitude', first channel
        for sig_target in all_signals:
            if sig_target.description == 'n1 finite frequency signals':
                target_plot = shot.signals_dict[sig_target][:, 0]
                print(target_plot.shape)
            elif sig_target.description == 'Locked mode amplitude':
                lm_plot = shot.signals_dict[sig_target][:, 0]

        n_profiles = 0
        for i, sig in enumerate(use_signals):
            num_channels = sig.num_channels
            sig_arr = shot.signals_dict[sig]
            if num_channels == 1:
                j = i // 7
                ax = axarr[j]
                ax.plot(xx, sig_arr[:, 0], linewidth=2, color=colors[i % 7],
                        label=sig.description)
                if j == 0:
                    ax.set_ylim([-3, 15])
                    ax.set_yticks([0, 5, 10])
                else:
                    ax.set_ylim([-15, 15])
                    ax.set_yticks([-10, -5, 0, 5, 10])
            else:
                # Profiles fill panels upward from the one above the output.
                ax = axarr[-2 - n_profiles]
                n_profiles += 1
                ax.imshow(sig_arr[:, :].T, aspect='auto',
                          label=sig.description, cmap="inferno")
                ax.set_ylim([0, num_channels])
                ax.text(lower_lim + 200, 45, sig.description,
                        bbox={'facecolor': 'white', 'pad': 10},
                        fontsize=fontsize)
                ax.set_yticks([0, num_channels / 2])
                ax.set_yticklabels(["0", "0.5"])
                ax.set_ylabel("$\\rho$", size=fontsize)
            ax.legend(loc="center left", labelspacing=0.1,
                      bbox_to_anchor=(1, 0.5), fontsize=fontsize,
                      frameon=False)
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), fontsize=fontsize)
            f.subplots_adjust(hspace=0)

        # Bottom panel: predictions and rescaled reference traces.
        ax = axarr[-1]
        print('predictions shape:', prediction.shape)
        print('truth shape:', truth.shape)

        # Clip negative scores on a copy so the caller's array is not mutated.
        prediction = np.array(prediction, copy=True)
        prediction[prediction < 0] = 0.0
        maxii = np.amax(prediction)
        truth_plot = truth[:, 1] / np.amax(truth[:, 1]) * maxii
        if lm_plot is not None:
            lm_plot = lm_plot / np.amax(lm_plot) * maxii

        print('******************************************************')
        print('******************************************************')
        print('******************************************************')
        print('Truth_plot', truth_plot[:-10])
        if lm_plot is not None:
            print('lm_plot', lm_plot[:-10])
        print('******************************************************')
        if target_plot is not None:
            target_plot = target_plot / np.amax(target_plot) * maxii

        ax.plot(xx, truth_plot, 'yellow', label='truth')
        if lm_plot is not None:
            ax.plot(xx, lm_plot, 'pink', label='Locked mode amplitude')
        if target_plot is not None:
            ax.plot(xx, target_plot, 'cyan', label='n1rms')
        ax.plot(xx, prediction[:, 0], 'blue',
                label='FRNN-U predicted n=1 mode ', linewidth=2)
        ax.plot(xx, prediction[:, 1], 'red',
                label='FRNN-U predicted locked mode ', linewidth=2)
        ax.set_ylim([0, maxii])
        print('predictions:', shot.number, prediction)
        ax.set_xlabel('T [ms]', size=fontsize)
        ax.legend(loc="center left", labelspacing=0.1,
                  bbox_to_anchor=(1, 0.5), fontsize=fontsize + 2,
                  frameon=False)
        plt.setp(ax.get_yticklabels(), fontsize=fontsize)
        plt.setp(ax.get_xticklabels(), fontsize=fontsize)
        plt.xlim([lower_lim, len(truth)])
        if save_fig:
            plt.savefig('sig_fig_{}{}.png'.format(shot.number, extra_filename),
                        bbox_inches='tight')