Example #1
import torch

# Repo-local helper assumed in scope: input_sample


def vanilla_CGAN_step(netD,
                      netG,
                      optD,
                      optG,
                      loss_func,
                      device,
                      X_i=None,
                      C_i=None,
                      Z_i=None,
                      real_label=1,
                      fake_label=0,
                      n_D=1):
    '''
    One training step of a vanilla conditional GAN: update the
    discriminator netD n_D times on the real/fake batch (X_i, C_i),
    then update the generator netG once with the non-saturating loss.
    Z_i is unused here; it is accepted for interface compatibility
    with the supervised step.
    '''
    # Train discriminator n_D times per batch
    for j in range(n_D):
        # Train on all-real batch
        netD.zero_grad()
        real_sample = torch.cat((X_i, C_i), dim=1)
        b_size = X_i.size(0)
        label = torch.full((b_size, ),
                           real_label,
                           device=device,
                           dtype=torch.get_default_dtype())
        output = netD(real_sample)
        errD_real = loss_func(output.view(-1), label.view(-1))
        errD_real.backward()
        D_x = output.mean().item()

        # Train on all-fake batch
        input_G = input_sample(b_size, C=C_i, device=device)
        fake = torch.cat((netG(input_G).detach(), C_i), dim=1)
        label.fill_(fake_label)
        output = netD(fake)
        errD_fake = loss_func(output.view(-1), label.view(-1))
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optD.step()

    # Train generator
    netG.zero_grad()

    # Create fake sample
    input_G = input_sample(b_size, C=C_i, device=device)
    fake = torch.cat((netG(input_G), C_i), dim=1)
    output = netD(fake)
    errG = -torch.log(output).mean()  # non-saturating loss ('log trick'): maximise log D(G(z))
    errG.backward()
    D_G_z2 = output.mean().item()
    optG.step()

    return errD, errG, D_x, D_G_z1, D_G_z2
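A hedged usage sketch: the wiring below assumes Discriminator, Generator, and loader exist elsewhere in the project and that both networks accept the concatenated (sample, condition) tensors built inside the step; none of these objects are defined in the snippet itself.

import torch
import torch.nn as nn
import torch.optim as optim

netD, netG = Discriminator(), Generator()  # hypothetical model classes
optD = optim.Adam(netD.parameters(), lr=1e-4)
optG = optim.Adam(netG.parameters(), lr=1e-4)
loss_func = nn.BCELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for X_i, C_i in loader:  # hypothetical DataLoader yielding (batch, condition) pairs
    errD, errG, D_x, D_G_z1, D_G_z2 = vanilla_CGAN_step(
        netD, netG, optD, optG, loss_func, device,
        X_i=X_i.to(device), C_i=C_i.to(device), n_D=1)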
Example #2
    def QQ_plots(self,C=None,n_plot=None,G=None,params=None,results_path=None,save=False):
        '''
        Method to create QQ plots of generated data and exact test data.
        '''
        if (not save) and self.save_all_figs:
            save = True
        if params is None:
            params = self.params
        if results_path is None:
            results_path = self.results_path
        if G is None:
            G = self.G
        if (C is None) and self.CGAN:
            C = self.C_test
            assert G.c_dim > 0, 'Generator must be a conditional GAN if CGAN is toggled in the dataset.'
        elif self.CGAN:
            assert G.c_dim > 0, 'Generator must be a conditional GAN if CGAN is toggled in the dataset.'
            assert str(C.keys()) == str(self.C_test.keys()), 'The conditions specified must include all conditional parameters, in the same order as C_test.'
        else:
            # Vanilla GAN case
            C = dict()    
        if n_plot is None:
            if self.format == 'pdf':
                n_plot = 1000
            else:
                n_plot = self.N_test
        # Update parameters accordingly. If vanilla GAN, params remain unchanged 
        params = {**params,**C}

        if self.CGAN:
            C_test_tensor = make_test_tensor(C,n_plot,device=self.device)
            in_sample = input_sample(n_plot,C=C_test_tensor,device=self.device)
        else:
            in_sample = input_sample(n_plot,C=None,device=self.device)

        gendata = postprocess(G(in_sample).detach(),
                              params['S0'],
                              proc_type=self.proc_type,
                              S_ref=torch.tensor(params['S_bar'],
                                                 device=self.device,
                                                 dtype=torch.float32),
                              eps=self.eps).cpu().view(-1).numpy()

        if self.SDE == 'GBM':
            mu = params['mu']
            sigma = params['sigma']
            S0 = params['S0']
            t = params['t']

            # GBM: S_t is lognormal; in scipy.stats.lognorm terms the shape is
            # sigma*sqrt(t) and the scale is exp(log(S0) + (mu - sigma^2/2)*t)
            scale = np.exp(np.log(S0) + (mu - 0.5*sigma**2)*t)
            s = sigma*np.sqrt(t)

            dist = stat.lognorm
            sparams = (s, 0, scale)
        elif self.SDE == 'CIR':
            kappa = params['kappa']
            gamma = params['gamma']
            S_bar = params['S_bar']
            S0 = params['S0']
            s = params['s']
            t = params['t']

            # CIR: S_t | S_s is a scaled noncentral chi-squared variable with
            # delta degrees of freedom, non-centrality kappa_bar and scale c_bar
            kappa_bar = (4*kappa*S0*np.exp(-kappa*(t-s)))/(gamma**2*(1-np.exp(-kappa*(t-s))))
            c_bar = (gamma**2)/(4*kappa)*(1-np.exp(-kappa*(t-s)))
            delta = (4*kappa*S_bar)/(gamma**2)

            dist = stat.ncx2
            sparams = (delta, kappa_bar, 0, c_bar)
        else:
            raise ValueError('SDE type not supported or understood.')

        # Test set 
        plt.figure(dpi=100)
        stat.probplot(x=self.sample_exact(N=n_plot, params=params).view(-1).cpu().numpy(),
                      dist=dist, sparams=sparams, plot=plt)
        plt.title('')        
        if save:
            plt.savefig(results_path+'_QQ_test.'+self.format,format=self.format)
        else:
            plt.show()
        plt.close()

        # Generated
        plt.figure(dpi=100)
        stat.probplot(x=gendata,dist=dist,sparams=sparams,plot=plt)
        plt.title('')
        if save:
            plt.savefig(results_path+'_QQ_generated.'+self.format,format=self.format)
        else:
            plt.show()
        plt.close()

        if save:
            print('Saved QQ plots of exact variates and generated data in folder %s' % results_path)
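The sparams tuples above follow SciPy's (shape, loc, scale) convention. A minimal, self-contained sanity check of the GBM parameterisation, with illustrative values not taken from the snippet:

import numpy as np
import scipy.stats as stat

mu, sigma, S0, t = 0.05, 0.2, 1.0, 1.0
scale = np.exp(np.log(S0) + (mu - 0.5*sigma**2)*t)
s = sigma*np.sqrt(t)

# exact GBM terminal values and the matching scipy.stats lognormal
Z = np.random.default_rng(0).standard_normal(100_000)
S_t = S0*np.exp((mu - 0.5*sigma**2)*t + sigma*np.sqrt(t)*Z)
print(S_t.mean(), stat.lognorm.mean(s, 0, scale))  # both close to S0*exp(mu*t)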
Example #3
    def save_iter_plot(self,iteration,filename=None,C=None,G=None,D=None,params=None,save_conf=True):
        '''
        Plots a kde plot and the discriminator confidence D(x). Saves the
        figure to `filename`.
        '''
        if params is None:
            params = self.params.copy()
        if G is None:
            G = self.G
        if D is None:
            D = self.D
        if (C is None) and self.CGAN:
            C = self.C_test
        if filename is None:
            filename = os.path.join(self.results_path,'plot_iter_%02d.'%iteration + self.format)
        if self.CGAN:
            assert G.c_dim > 0, 'Generator must be a conditional GAN if CGAN is toggled in the dataset.'
            assert str(C.keys()) == str(self.C_test.keys()), 'The conditions specified must include all conditional parameters, in the same order as C_test.'
        else:
            # Vanilla GAN case
            C = dict()

        # Update the relevant parameters with C
        params = {**params,**C}

        #------------------------------------------------------
        # Compute inputs
        #------------------------------------------------------

        if self.CGAN:
            C_test_tensor = make_test_tensor(C,self.N_test,device=self.device)
            in_sample = input_sample(self.N_test,C=C_test_tensor,Z=self.fixed_noise.to(self.device),device=self.device)
        else:
            in_sample = input_sample(self.N_test,C=None,Z=self.fixed_noise,device=self.device)

        output = G(in_sample).detach()
        gendata = postprocess(output.view(-1),
                              params['S0'],
                              proc_type=self.proc_type,
                              S_ref=torch.tensor(params['S_bar'],
                                                 device=self.device,
                                                 dtype=torch.float32),
                              eps=self.eps).cpu().view(-1).numpy()

        # Pre-process the exact variates; eps (e.g. 1e-6) keeps the fraction
        # X_next/X_prev finite
        exact_raw = preprocess(self.sample_exact(N=self.N_test, params=params),
                               torch.tensor(params['S0'],
                                            dtype=torch.float32).view(-1, 1),
                               proc_type=self.proc_type,
                               S_ref=torch.tensor(params['S_bar'],
                                                  device=torch.device('cpu'),
                                                  dtype=torch.float32),
                               eps=self.eps).cpu().view(-1).numpy()
        
        output = output.view(-1).cpu().numpy()
        
        # Define domain based on GAN output 
        a1 = np.min(output)-0.1*np.abs(output.min())
        b1 = np.min(exact_raw)-0.1*np.abs(exact_raw.min())
        a2 = np.max(output)+0.1*np.abs(output.max())
        b2 = np.max(exact_raw)+0.1*np.abs(exact_raw.max())
        a = np.min((a1,b1))
        b = np.max((a2,b2))
        if a == 0:
            a -= 1e-20        

        if not self.supervised:

            # Define grid for KDE to be computed on 
            x_opt = np.linspace(a,b,1000)

            # Compute exact density p* and generator density p_th

            if self.proc_type is None:
                # Use exact pdf for p* if no pre-processing is used
                p_star = self.get_exact_pdf(params)(x_opt)
            else:
                # Otherwise use kernel estimate to compute p*
                p_star = FFTKDE(kernel='gaussian',bw='silverman').fit(exact_raw).evaluate(x_opt)

            kde_th = FFTKDE(kernel='gaussian',bw='silverman').fit(output)
            p_th = kde_th.evaluate(x_opt)

            # Optimal discriminator given G
            D_opt = p_star/(p_th+p_star)

            x_D = torch.linspace(x_opt.min(),x_opt.max(),self.N_test)

            # Build the input to the discriminator; this branch only runs in
            # the unsupervised case, so no Z column is appended
            input_D = x_D.view(-1,1).to(self.device)
            if self.CGAN:
                input_D = torch.cat((input_D,C_test_tensor),dim=1)
   
   
        #------------------------------------------------------
        # Select the number of subplots to be shown
        #------------------------------------------------------
        # Only plot pre-processed data if pre-processing is not None
        # Only plot discriminator confidence if vanilla GAN is used

        single_handle = False  # toggled when the axis handle is not an array
        if (self.proc_type is None) and (self.supervised):
            fig,ax = plt.subplots(1,1,figsize=(10,10),dpi=100)
            title_string = 'Generator output'
            single_handle = True
        elif (self.proc_type is None) and (not self.supervised):
            fig,ax = plt.subplots(1,2,figsize=(20,10),dpi=100)
            title_string = 'Generator output'
        elif (self.proc_type is not None) and (self.supervised):
            fig,ax = plt.subplots(1,2,figsize=(20,10),dpi=100)
            title_string = 'Post-processed data'
        else:
            fig,ax = plt.subplots(1,3,figsize=(30,10),dpi=100)
            title_string = 'Post-processed data'

        k_ax = 0 # counter for axis index 

        #------------------------------------------------------
        # Plot 1: Post-processed data
        #------------------------------------------------------
        y = self.x
        ymin = y.min()-0.1*np.abs(y.min())
        ymax = y.max()+0.1*np.abs(y.max())

        exact_pdf = self.get_exact_pdf(params)            

        if single_handle:
            ax_plot_1 = ax
        else:
            ax_plot_1 = ax[k_ax]

        ax_plot_1.plot(y,exact_pdf(y),'--k',label='Exact pdf')
        sns.kdeplot(gendata,shade=True,ax=ax_plot_1,label='Generated data')
        ax_plot_1.set_xlabel('$S_t$')
        # fig.suptitle(f'time = {self.T}')
        ax_plot_1.legend()
        # ax_plot_1.set_xlim(xmin=ymin,xmax=ymax)
        ax_plot_1.autoscale(enable=True, axis='x', tight=True)
        ax_plot_1.autoscale(enable=True, axis='y')
        ax_plot_1.set_ylim(bottom=0)        
        ax_plot_1.set_title(title_string)

        # Also save the kde plot on its own as a pdf
        f_kde,ax_kde = plt.subplots(1,1,dpi=100)
        ax_kde.plot(y,exact_pdf(y),'--k',label='Exact pdf')
        sns.kdeplot(gendata,shade=True,ax=ax_kde,label='Generated data')
        ax_kde.set_xlabel('$S_t$')
        ax_kde.legend()
        ax_kde.set_xlim(xmin=ymin,xmax=ymax)
        # ax_kde.set_title(title_string)
        f_kde.suptitle(f'Iteration {iteration}')
        f_kde.savefig(os.path.join(self.results_path,'kde_output_iter_%02d'%iteration+'.pdf'),format='pdf')
        plt.close(f_kde)

        #------------------------------------------------------
        # Plot 2: Generator output
        #------------------------------------------------------
        if self.proc_type is not None:
            k_ax += 1
            sns.kdeplot(exact_raw,linestyle='--',color='k',ax=ax[k_ax],label='Pre-processed exact')        
            sns.kdeplot(output,shade=True,ax=ax[k_ax],label='Generated data')
            ax[k_ax].set_xlabel('$R_t$')
            ax[k_ax].legend()
            # ax[k_ax].set_xlim(xmin=a,xmax=b)
            ax[k_ax].autoscale(enable=True, axis='x', tight=True)
            ax[k_ax].autoscale(enable=True, axis='y')
            ax[k_ax].set_ylim(bottom=0)            
            ax[k_ax].set_title('Generator output')

        #------------------------------------------------------
        # Plot 3: Discriminator confidence
        #------------------------------------------------------    

        if not self.supervised:
            k_ax += 1
            ax[k_ax].plot(x_D,D(input_D).detach().view(-1).cpu().numpy(),label='Discriminator output')
            ax[k_ax].plot(x_opt,D_opt,'--k',label='Optimal discriminator')

            # ax[1].set_title('Discriminator confidence')
            if self.proc_type is None:
                ax[k_ax].set_xlabel('$S_t$')
            else:
                ax[k_ax].set_xlabel('$R_t$')
            ax[k_ax].legend()
            # ax[k_ax].set_xlim(xmin=a,xmax=b)

            ax[k_ax].autoscale(enable=True, axis='x', tight=True)
            ax[k_ax].autoscale(enable=True, axis='y')
            ax[k_ax].set_ylim(bottom=0)

            if save_conf:
                # Repeat plot to save discriminator confidence itself as well
                f_conf,ax_conf = plt.subplots(1,1,dpi=100)
                ax_conf.plot(x_D,D(input_D).detach().view(-1).cpu().numpy(),label='Discriminator output')
                ax_conf.plot(x_opt,D_opt,'--k',label='Optimal discriminator')

                if self.proc_type is None:
                    ax_conf.set_xlabel('$S_t$')
                else:
                    ax_conf.set_xlabel('$R_t$')
                ax_conf.legend()
                ax_conf.set_xlim(xmin=a,xmax=b)
                f_conf.suptitle(f'Iteration {iteration}')
                f_conf.savefig(os.path.join(self.results_path,'D_conf_iter_%02d'%iteration+'.pdf'),format='pdf')
                plt.close(f_conf)

        #------------------------------------------------------
        # Wrap up
        #------------------------------------------------------

        fig.suptitle(f'Iteration {iteration}')
        fig.savefig(filename,format=self.format)
        plt.close()
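The D_opt curve above is the optimal-discriminator identity for a fixed generator, D*(x) = p*(x)/(p*(x) + p_th(x)) (Goodfellow et al., 2014). A minimal, self-contained illustration with synthetic stand-ins for the exact and generated samples:

import numpy as np
from KDEpy import FFTKDE

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, 10_000)  # stand-in for exact samples
fake = rng.normal(0.5, 1.2, 10_000)  # stand-in for generator output

x = np.linspace(-8, 9, 1024)
p_star = FFTKDE(kernel='gaussian', bw='silverman').fit(real).evaluate(x)
p_th = FFTKDE(kernel='gaussian', bw='silverman').fit(fake).evaluate(x)
D_opt = p_star/(p_th + p_star)  # approaches 0.5 where the densities agree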
Example #4
    def kde_plot(self,C=None,G=None,params=None,save=False,filename=None,raw_output=False,save_format=None,x_lims=None):
        '''
        Plotting function that creates a kde plot of generated versus exact data.
        '''

        if params is None:
            params = self.params.copy()
        if G is None:
            G = self.G
        if self.CGAN:
            params = {**params,**self.C_test}
            assert G.c_dim > 0, 'Generator appears not to be trained as a CGAN, while CGAN is toggled.'
            assert len(C) == 1, 'Only one plot condition is supported. To fix other parameters, specify them in self.C_test'
        if save_format is None:
            save_format = self.format
        if save:
            assert filename is not None, 'Please specify a filename to save the figure to.'
            
        # Define plotting linestyles 
        lines = ['dashed','solid','dotted','dashdot']
        # Initialise handles for legend entries 
        handles_exact = []
        handles = []
        names = []
        fig,ax = plt.subplots(1,1,dpi=100)
        
        if G.c_dim > 0:
            #------------------------------------------------
            # Case 1: Conditional GAN 
            #------------------------------------------------
            # Get name of conditional parameter
            c_name = next(iter(C.keys()))
            # Get array with plot values
            cs = np.array(next(iter(C.values())))

            for i,c in enumerate(cs):
                # Cycle between linestyles
                line = lines[i%len(lines)]
                # First cast the current condition back to a dict
                c_dict = dict()
                c_dict[c_name] = c
                # Cast current condition into tensor, replacing the relevant value in C_test
                c_tensor = make_test_tensor({**self.C_test,**c_dict},self.N_test,device=self.device)
                # Get an input sample 
                in_sample = input_sample(self.N_test,C=c_tensor,device=self.device)
                # Infer with generator
                output = G(in_sample).detach()
                gendata = postprocess(output,
                                      {**params,**c_dict}['S0'],
                                      proc_type=self.proc_type,
                                      S_ref=torch.tensor(params['S_bar'],
                                                         device=self.device,
                                                         dtype=torch.float32),
                                      eps=self.eps).cpu().view(-1).numpy()


                if raw_output and (self.proc_type is not None):
                    # Get pre-processed exact variates for estimate of exact pdf
                    exact_raw = preprocess(self.sample_exact(N=self.N_test, params={**params,**c_dict}),
                                           torch.tensor({**params,**c_dict}['S0'],
                                                        dtype=torch.float32).view(-1,1),
                                           proc_type=self.proc_type,
                                           S_ref=torch.tensor(params['S_bar'],
                                                              device=self.device,
                                                              dtype=torch.float32),
                                           eps=self.eps).cpu().view(-1).numpy()

                    x_out_kde_exact, kde_exact = FFTKDE(kernel='gaussian',bw='silverman').fit(exact_raw)(self.N_test)
                    x_out_kde_gen, kde_gen = FFTKDE(kernel='gaussian',bw='silverman').fit(output.view(-1).cpu().numpy())(self.N_test)

                    l_e, = ax.plot(x_out_kde_exact,kde_exact,'k',linestyle=line)
                    l_g, = ax.plot(x_out_kde_gen,kde_gen,'-')

                    # Define the color if only one condition is chosen 
                    if (len(cs) == 1):
                        l_g.set_color('darkorange')

                    ax.fill_between(x_out_kde_gen, y1=kde_gen, alpha=0.25, facecolor=l_g.get_color())
                    plt.autoscale(enable=True, axis='x', tight=True)
                    ax.set_ylim(bottom=0)
                    handles_exact.append(l_e)
                    handles.append(l_g)
                    names.append(c_name+f' = {c}')
                    ax.set_xlabel('$R_t$')
                else:
                    a = 0.1*gendata.min()
                    b = 1.1*gendata.max()
                    if a <= 1e-4:
                        a = self.eps
                    x = np.linspace(a,b,1000)
                    # Instantiate function for pdf of analytical distribution
                    exact_pdf = self.get_exact_pdf({**params,**c_dict})
                    x_out_kde_gen, kde_gen = FFTKDE(kernel='gaussian',bw='silverman').fit(gendata)(self.N_test)

                    l_e, = ax.plot(x,exact_pdf(x),'k',linestyle=line)
                    l_g, = ax.plot(x_out_kde_gen,kde_gen,'-')

                    # Define the color if only one condition is chosen 
                    if (len(cs) == 1):
                        l_g.set_color('darkorange')
                    ax.fill_between(x_out_kde_gen, y1=kde_gen, alpha=0.25, facecolor=l_g.get_color())
                    handles_exact.append(l_e)
                    handles.append(l_g)
                    names.append(c_name+f' = {c}')
                    ax.set_xlabel('$S_t$')

            # Make final handles for legend            
            names.append('Exact')         
            handles.append(tuple(handles_exact))
            ax.legend(handles, names, numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)})

            # Optional manual setting of horizontal axis limits
            if x_lims is not None:
                ax.set_xlim(x_lims)
            else:
                ax.autoscale(enable=True, axis='x', tight=True)
                ax.autoscale(enable=True, axis='y')
                ax.set_ylim(bottom=0)

        else:
            #------------------------------------------------
            # Case 2: Unconditional GAN 
            #------------------------------------------------    
            in_sample = input_sample(self.N_test,device=self.device)
            # Infer with generator
            output = G(in_sample).detach()
            gendata = postprocess(output,
                                  params['S0'],
                                  proc_type=self.proc_type,
                                  S_ref=torch.tensor(params['S_bar'],
                                                     device=self.device,
                                                     dtype=torch.float32),
                                  eps=self.eps).cpu().view(-1).numpy()
            
            if raw_output and (self.proc_type is not None):
                exact_raw = preprocess(self.sample_exact(N=self.N_test, params=params),
                                       torch.tensor(params['S0'],
                                                    dtype=torch.float32).view(-1,1),
                                       proc_type=self.proc_type,
                                       K=calc_K(params,proc_type=self.proc_type,SDE=self.SDE),
                                       S_ref=torch.tensor(params['S_bar'],
                                                          device=self.device,
                                                          dtype=torch.float32),
                                       eps=self.eps).cpu().view(-1).numpy()
                sns.kdeplot(exact_raw,label='Exact',ax=ax,linestyle='--',color='k')
                sns.kdeplot(output.view(-1).cpu().numpy(),label='Generated',ax=ax,shade=True,color='darkorange')
                ax.set_xlabel('$R_t$')
            else:    
                a = 0.1*gendata.min()
                b = 1.1*gendata.max()
                if a <= 1e-4:
                    a = self.eps
                x = np.linspace(a,b,1000)
                # Instantiate function for pdf of analytical distribution
                exact_pdf = self.get_exact_pdf(params)
                ax.plot(x,exact_pdf(x),'--k',label='Exact')
                sns.kdeplot(gendata,label='Generated',ax=ax,shade=True,color='DarkOrange')
                ax.set_xlabel('$S_t$')
            
            ax.legend()
            # Optional manual setting of horizontal axis limits
            if x_lims is not None:
                ax.set_xlim(x_lims)

        # fig.suptitle(c_name + f' = {cs}')
        # fig.suptitle(f'{str( {**self.C_test,**C}) }')

        #------------------------------------------------
        # Wrap up 
        #------------------------------------------------   
        
        if save:
            plt.savefig(filename,format=save_format)
            print(f'Saved kde_plot at {filename}')
            plt.close()
        else:
            plt.show()  
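A hedged usage sketch, assuming analysis is an instance of this class trained as a CGAN and that 'sigma' is one of the conditional parameters in its C_test (both names are hypothetical):

analysis.kde_plot(C={'sigma': [0.1, 0.2, 0.4]},  # hypothetical condition name
                  save=True,
                  filename='kde_sigma.pdf')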
Example #5
    def ecdf_plot(self,C=None,G=None,params=None,save=False,filename=None,raw_output=False,save_format=None,x_plot=None,x_lims=None,legendname=None,grid=False):
        '''
        Plotting function that creates an ECDF plot of generated versus exact data.
        '''
        if (not save) and self.save_all_figs:
            save = True
        if params is None:
            params = self.params.copy()
        if G is None:
            G = self.G
        if self.CGAN:
            params = {**params,**self.C_test}
            assert G.c_dim > 0, 'Generator appears not to be trained as a CGAN, while CGAN is toggled.'
            assert len(C) == 1, 'Only one plot condition is supported. To fix other parameters, specify them in self.C_test'
        if save_format is None:
            save_format = self.format
        if save:
            assert filename is not None, 'Please specify a filename to save the figure to.'
            
        # Define plotting linestyles 
        lines = ['dashed','solid','dotted','dashdot']
        # Initialise handles for legend entries 
        handles_exact = []
        handles = []
        names = []
        fig,ax = plt.subplots(1,1,dpi=100)
        
        if G.c_dim > 0:
            #------------------------------------------------
            # Case 1: Conditional GAN 
            #------------------------------------------------
            # Get name of the conditional parameter
            c_name = next(iter(C.keys()))
            # Prepare the base string of the legend name
            if legendname is None:
                legendname = c_name            
            # Get array with plot values
            cs = np.array(next(iter(C.values())))

            for i,c in enumerate(cs):
                # Cycle between linestyles
                line = lines[i%len(lines)]
                # First cast the current condition back to a dict
                c_dict = dict()
                c_dict[c_name] = c
                # Cast current condition into tensor, replacing the relevant value in C_test
                c_tensor = make_test_tensor({**self.C_test,**c_dict},self.N_test,device=self.device)
                # Get an input sample 
                in_sample = input_sample(self.N_test,C=c_tensor,device=self.device)
                # Infer with generator
                output = G(in_sample).detach()
                gendata = postprocess(output,
                                      {**params,**c_dict}['S0'],
                                      proc_type=self.proc_type,
                                      delta_t=torch.tensor(params['t']),
                                      S_ref=torch.tensor(params['S_bar'],
                                                         device=self.device,
                                                         dtype=torch.float32),
                                      eps=self.eps).view(-1).cpu().numpy()

                if raw_output and (self.proc_type is not None):
                    if x_plot is None:
                        a,b = self.get_plot_bounds(output)
                        x = np.linspace(a,b,1000)
                    else:
                        x = x_plot    
                    # Option to plot output before postprocessing 
                    ecdf_gen_data = smd.ECDF(output.view(-1).cpu().numpy())
                    # Get pre-processed exact variates for estimate of exact cdf 
                    exact_raw = preprocess(self.sample_exact(N=self.N_test, params={**params,**c_dict}),
                                           torch.tensor({**params,**c_dict}['S0'],
                                                        dtype=torch.float32).view(-1,1),
                                           proc_type=self.proc_type,
                                           S_ref=torch.tensor(params['S_bar'],
                                                              device=self.device,
                                                              dtype=torch.float32),
                                           eps=self.eps).view(-1).cpu().numpy()
                    exact_cdf = smd.ECDF(exact_raw)

                    l_e, = ax.plot(x,exact_cdf(x),'k',linestyle=line)
                    l_g, = ax.plot(x,ecdf_gen_data(x),'-')
                    handles_exact.append(l_e)
                    handles.append(l_g)
                    names.append(legendname + f' = {c}')
                    ax.set_xlabel('$R_t$')
                else:
                    if x_plot is None:
                        a = 0.1*gendata.min()
                        b = 1.1*gendata.max()
                        if a <= 1e-4:
                            a = self.eps
                        x = np.linspace(a,b,1000)
                    else:
                        x = x_plot                    

                    # Plot output after pre-processing 
                    ecdf_gen_data = smd.ECDF(gendata)
                    # Instantiate function for cdf of analytical distribution
                    exact_cdf = self.get_exact_cdf({**params,**c_dict})

                    l_e, = ax.plot(x,exact_cdf(x),'k',linestyle=line)
                    l_g, = ax.plot(x,ecdf_gen_data(x),'-')
                    handles_exact.append(l_e)
                    handles.append(l_g)
                    names.append(legendname + f' = {c}')
                    ax.set_xlabel(r'$S_{t+\Delta t} \mid S_t$')

            # Make final handles for legend            
            names.append('Exact')         
            handles.append(tuple(handles_exact))
            ax.legend(handles, names, numpoints=1, handler_map={tuple: HandlerTuple(ndivide=None)})

        else:
            #------------------------------------------------
            # Case 2: Unconditional GAN 
            #------------------------------------------------    
            in_sample = input_sample(self.N_test,device=self.device)
            # Infer with generator
            output = G(in_sample).detach()
            gendata = postprocess(output,
                                  params['S0'],
                                  proc_type=self.proc_type,
                                  S_ref=torch.tensor(params['S_bar'],
                                                     device=self.device,
                                                     dtype=torch.float32),
                                  eps=self.eps).view(-1).cpu().numpy()
            
            if raw_output and (self.proc_type is not None):
                if x_plot is None:
                    a,b = self.get_plot_bounds(output)
                    x = np.linspace(a,b,1000)
                else:
                    x = x_plot
                ecdf_gen_data = smd.ECDF(output.view(-1).cpu().numpy())
                exact_raw = preprocess(self.sample_exact(N=self.N_test, params=params),
                                       torch.tensor(params['S0'],
                                                    dtype=torch.float32).view(-1,1),
                                       proc_type=self.proc_type,
                                       S_ref=torch.tensor(params['S_bar'],
                                                          device=self.device,
                                                          dtype=torch.float32),
                                       eps=self.eps).view(-1).cpu().numpy()
                ax.plot(x,smd.ECDF(exact_raw)(x),'--k',label='Exact')
                ax.plot(x,ecdf_gen_data(x),'-',label='Generated')
                ax.set_xlabel('$R_t$')
            else:    
                a = 0.1*gendata.min()
                b = 1.1*gendata.max()
                if a <= 1e-4:
                    a = self.eps
                x = np.linspace(a,b,1000)
                ecdf_gen_data = smd.ECDF(gendata)
                # Instantiate functions for cdf and pdf of analytical distribution
                exact_cdf = self.get_exact_cdf(params)
                ax.plot(x,exact_cdf(x),'--k',label='Exact')
                ax.plot(x,ecdf_gen_data(x),label='Generated')
                ax.set_xlabel(r'$S_{t+\Delta t} \mid S_t$')

            # Optional manual setting of horizontal axis limits
            if x_lims is not None:
                ax.set_xlim(x_lims)
            
            ax.legend()

        # Commented out for the paper
        # fig.suptitle(c_name + f' = {cs}')
        # fig.suptitle(f'{str( {**self.C_test,**C}) }')

        #------------------------------------------------
        # Wrap up 
        #------------------------------------------------   
        
        if grid:
            plt.grid(True)

        if save:
            plt.savefig(filename,format=save_format)
            print(f'Saved ecdf_plot at {filename}')
            plt.close()
        else:
            plt.show()  
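For reference, a minimal, self-contained version of the comparison drawn above: the empirical CDF of lognormal samples against the exact lognormal CDF. The parameters are illustrative, and ECDF is the statsmodels class aliased as smd.ECDF in the snippet.

import numpy as np
import scipy.stats as stat
import matplotlib.pyplot as plt
from statsmodels.distributions.empirical_distribution import ECDF

samples = stat.lognorm.rvs(0.2, 0, 1.0, size=5000, random_state=0)
x = np.linspace(samples.min(), samples.max(), 1000)

plt.plot(x, stat.lognorm.cdf(x, 0.2, 0, 1.0), '--k', label='Exact')
plt.plot(x, ECDF(samples)(x), label='Generated')
plt.legend()
plt.show()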
Example #6
import os
import os.path as pt
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.stats as stat
import statsmodels.distributions as smd

# Imports reconstructed from usage; the aliases pt, stat and smd are assumed.
# Repo-local helpers assumed in scope: step_handler, CGANalysis,
# dict_to_tensor, make_test_tensor, input_sample, postprocess,
# count_layers, append_gradients.


def train_GAN(netD, netG, data, meta, netD_Mil=None):
    '''
    Training loop: train_GAN(netD, netG, data, meta)
    Inspired by several tricks from the PyTorch DCGAN tutorial.
    '''

    #-------------------------------------------------------
    # Initialisation
    #-------------------------------------------------------

    real_label = 1
    fake_label = 0
    GANloss = nn.BCELoss()

    GAN_step = step_handler(supervised_bool=meta['supervised'],
                            CGAN_bool=data.CGAN)

    # Initialise lists for logging
    D_losses = []
    G_losses = []
    times = []
    wasses = []
    ksstat = []
    delta_ts_passed = []
    D_grads = []
    G_grads = []

    # Short handle for training params
    c_lr = meta['c_lr']
    cut_lr_every = meta['cut_lr_every']
    epochs = meta['epochs']
    results_path = meta['results_path']
    b_size = meta['batch_size']
    proc_type = meta['proc_type']
    device = meta['device']

    if not pt.exists(pt.join(results_path, 'training', '')):
        os.mkdir(pt.join(results_path, 'training', ''))
    train_analysis = CGANalysis(data, netD, netG,
                                SDE=data.SDE,
                                save_all_figs=meta['save_figs'],
                                results_path=pt.join(results_path, 'training', ''),
                                proc_type=proc_type,
                                eps=meta['eps'],
                                supervised=meta['supervised'],
                                device=device)
    train_analysis.format = 'png'

    if data.C is not None:
        C_tensors = dict_to_tensor(data.C)
        C_test = make_test_tensor(data.C_test, data.N_test)
    else:
        C_test = None

    # Pre-allocate an empty list per layer to store the gradient norms
    for l in range(count_layers(netD)):
        D_grads.append([])

    for l in range(count_layers(netG)):
        G_grads.append([])

    # Initialise counters
    itervec = []
    iters = 0
    plot_iter = 0

    # Get the number of batches implied by training set size and batch_size
    n_batches = data.N // b_size

    optG = optim.Adam(netG.parameters(),
                      lr=meta['lr_G'],
                      betas=(meta['beta1'], meta['beta2']))
    optD = optim.Adam(netD.parameters(),
                      lr=meta['lr_D'],
                      betas=(meta['beta1'], meta['beta2']))

    #-------------------------------------------------------
    # Start training loop
    #-------------------------------------------------------

    for epoch in range(epochs):
        tick0 = time.time()
        for i in range(n_batches):
            if iters % cut_lr_every == 0:
                optG.param_groups[0]['lr'] = optG.param_groups[0]['lr'] / c_lr

            # Sample random minibatch from training set with replacement
            indices = np.array((np.random.rand(b_size) * data.N), dtype=int)
            # Uncomment to sample minibatch from training set without replacement
            # indices = np.arange(i*b_size,(i+1)*b_size)

            # Get data batch based on indices
            X_i = data.exact[indices, :].to(device)
            C_i = C_tensors[indices, :].to(device) if data.CGAN else None
            Z_i = data.Z[indices, :].to(device) if meta['supervised'] else None

            #-------------------------------------------------------
            # GAN training step
            #-------------------------------------------------------

            errD, errG, D_x, D_G_z1, D_G_z2 = GAN_step(netD,
                                                       netG,
                                                       optD,
                                                       optG,
                                                       GANloss,
                                                       device,
                                                       X_i=X_i,
                                                       C_i=C_i,
                                                       Z_i=Z_i,
                                                       real_label=real_label,
                                                       fake_label=fake_label,
                                                       n_D=meta['n_D'])

            #-------------------------------------------------------

            # Output training stats
            if (iters % 100 == 0) and (i % b_size) == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, epochs, i, data.N // b_size, errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))
                input_G = input_sample(data.N_test, C=C_test, device=device)
                fake = postprocess(netG(input_G).detach().view(-1),
                                   {**data.params,**data.C_test}['S0'],
                                   proc_type=proc_type,
                                   S_ref=torch.tensor(data.params['S_bar'],
                                                      device=device,
                                                      dtype=torch.float32),
                                   eps=meta['eps']).cpu().view(-1).numpy()
                ecdf_fake = smd.ECDF(fake)
                ecdf_test = smd.ECDF(data.exact_test.view(-1))
                # cdf_test = train_analysis.exact_cdf(params={**data.params,**C_test})
                x = train_analysis.x  # plotting vector

                # Infinity norm with ECDF on test data (Kolmogorov-Smirnov statistic)
                ksstat.append(np.max(np.abs(ecdf_fake(x) - ecdf_test(x))))
                # ksstat.append(stat.kstest(fake,cdf=cdf_test),alternative='two-sided')[0])

                # 1D Wasserstein distance as implemented in Scipy stats package
                wasses.append(
                    stat.wasserstein_distance(fake, data.exact_test.view(-1)))

                # Keep track of the L1 norm of the gradients in each layer
                append_gradients(netD, netG, D_grads, G_grads)

                itervec.append(iters)

            # Update the generated data in analysis instance
            if ((iters % meta['plot_interval'] == 0)
                    and (meta['save_iter_plot'] == True)):
                # Update network references for inference
                train_analysis.G = netG
                train_analysis.D = netD
                train_analysis.save_iter_plot(iters,
                                              params=data.params,
                                              D=netD,
                                              G=netG)
                plot_iter += 1

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            iters += 1

        tick1 = time.time()
        times.append(tick1 - tick0)

    #-------------------------------------------------------
    # Store training results
    #-------------------------------------------------------

    # Get range of training parameters if CGAN
    if data.C is not None:
        C_ranges = dict()
        for key in list(data.C.keys()):
            C_ranges[key] = (min(data.C[key]), max(data.C[key]))

    # Create dict for output log
    output_dict = dict(
        iterations=np.arange(1, iters + 1),
        iters=itervec,
        D_losses=D_losses,
        G_losses=G_losses,
        Wass_dist_test=wasses,
        KS_stat_test=ksstat,
        D_layers=[count_layers(netD)],
        G_layers=[count_layers(netG)],
        final_lr_G=[optG.param_groups[0]['lr']],
        final_lr_D=[optD.param_groups[0]['lr']],
        total_time=[np.sum(times)],
        train_condition=list(data.C.keys()) if data.C is not None else None,
        train_condition_ranges=[str(C_ranges)] if data.C is not None else None,
        test_condition=[str(data.C_test)] if data.C is not None else None,
        params_names=list(data.params),
        params=list(data.params.values()),
        SDE=[data.SDE],
    )

    # Add metaparameters to output log
    output_dict = {**meta, **output_dict}

    for k in range(len(G_grads)):
        output_dict['G_grad_layer_%d' % k] = G_grads[k]

    for k in range(len(D_grads)):
        output_dict['D_grad_layer_%d' % k] = D_grads[k]

    # Convert to Pandas DataFrame
    results_df = pd.DataFrame(
        dict([(k, pd.Series(v)) for k, v in output_dict.items()]))

    # Append the model and test-condition parameters as columns
    results_df = pd.concat(
        (results_df, pd.DataFrame({
            **data.params,
            **data.C_test
        }, index=[0])),
        axis=1)

    for k in range(len(G_grads)):
        results_df = pd.concat(
            [results_df,
             pd.DataFrame(G_grads[k], columns=['GradsG_L%d' % k])],
            axis=1,
            sort=False)
    for k in range(len(D_grads)):
        results_df = pd.concat(
            [results_df,
             pd.DataFrame(D_grads[k], columns=['GradsD_L%d' % k])],
            axis=1,
            sort=False)

    print('----- Training complete -----')

    return output_dict, results_df
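A hedged sketch of the meta dict this loop expects, reconstructed from the keys read above; the values are illustrative, not taken from the snippet, and netD, netG and data are assumed to exist:

meta = dict(
    supervised=False, n_D=2, epochs=50, batch_size=1000,  # illustrative values
    lr_G=1e-4, lr_D=1e-4, beta1=0.5, beta2=0.999,
    c_lr=1.05, cut_lr_every=5000,
    proc_type=None, eps=1e-6, device=torch.device('cpu'),
    results_path='results', save_figs=False,
    save_iter_plot=False, plot_interval=1000)

output_dict, results_df = train_GAN(netD, netG, data, meta)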