Exemple #1
0
def model(dim_x,
          dim_y,
          dim_z,
          dim_total,
          lr,
          l2_reg,
          meta_epoch,
          gamma,
          hidden_depth=8):
    """Build the invertible network plus its optimizer, LR schedule and losses.

    The graph is: one input node of width ``dim_total``, then ``hidden_depth``
    pairs of (GLOW coupling block, random permutation), one extra trailing
    coupling block, and an output node.

    Args:
        dim_x: Accepted for interface compatibility; not used in this function.
        dim_y: Accepted for interface compatibility; not used in this function.
        dim_z: Accepted for interface compatibility; not used in this function.
        dim_total: Width passed to the ``InputNode``.
        lr: Learning rate handed to Adam.
        l2_reg: Passed to Adam as ``weight_decay``.
        meta_epoch: Passed to ``StepLR`` as ``step_size``.
        gamma: Passed to ``StepLR`` as ``gamma``.
        hidden_depth: Number of (coupling, permutation) pairs before the
            trailing coupling block.

    Returns:
        Tuple ``(inn, optimizer, scheduler, loss_backward, loss_latent,
        loss_fit)``.
    """

    def coupling_args():
        # A fresh dict per node so no two nodes share one mutable argument dict.
        return {
            'subnet_constructor': F_fully_connected,
            'clamp': 2.0,
        }

    # Input layer node
    nodes = [InputNode(dim_total, name='input')]

    # Hidden layer nodes: alternate coupling blocks and permutations.
    # The names are real f-strings so each node gets a distinct, indexed name
    # (without the prefix every node was literally called 'coupling_{k}').
    for k in range(hidden_depth):
        nodes.append(
            Node(nodes[-1],
                 GLOWCouplingBlock,
                 coupling_args(),
                 name=f'coupling_{k}'))
        # NOTE(review): every permutation uses seed 1, so presumably all layers
        # apply the same shuffle - confirm this is intended (vs. seed=k).
        nodes.append(
            Node(nodes[-1], PermuteRandom, {'seed': 1}, name=f'permute_{k}'))

    # Trailing coupling block. It is indexed with hidden_depth rather than the
    # loop variable, which would repeat the last name and is undefined when
    # hidden_depth == 0.
    nodes.append(
        Node(nodes[-1],
             GLOWCouplingBlock,
             coupling_args(),
             name=f'coupling_{hidden_depth}'))

    # Output layer node
    nodes.append(OutputNode(nodes[-1], name='output'))

    # Build the reversible network
    inn = ReversibleGraphNet(nodes)

    # Optimizer
    # TODO: tune these parameters
    optimizer = torch.optim.Adam(inn.parameters(),
                                 lr=lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-04,
                                 weight_decay=l2_reg)

    # Learning-rate schedule
    # TODO: tune these parameters
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # Loss functions. Per the original author's note: x and z are unsupervised
    # (MMD), y is supervised (squared error).
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    return inn, optimizer, scheduler, loss_backward, loss_latent, loss_fit
Exemple #2
0
def _plot_latent_structure(model, y_obs, mode_1, mode_2, ndim_y, ndim_z,
                           ndim_tot, out_dir, device):
    """Plot how a 2-D slice of latent space maps back to parameter space.

    The first two latent coordinates are swept over a 100 x 100 grid spanning
    [-0.99, -0.01] in both directions; the remaining latent entries are drawn
    from a standard normal. Each grid point is run through the inverse pass of
    ``model`` together with ``y_obs``. A point is tagged 0 (green) when
    ``|mode_1 - x[0]| < |mode_2 - x[1]|`` and 1 (purple) otherwise, and is
    shaded by the smaller of those two distances, normalised per tag.
    The figure is written to ``<out_dir>struct_z.png``.
    """
    # NOTE(review): mode_1[0][0] assumes scipy.stats.mode returns array-valued
    # results (older SciPy behaviour) - confirm against the installed version.
    y_cond = torch.tensor(np.array(y_obs).reshape(1, ndim_y),
                          dtype=torch.float)
    zero_pad = torch.zeros(1, ndim_tot - ndim_y - ndim_z)

    # Rows 0-1: grid coordinates, row 2: shade intensity, row 3: tag (0 or 1).
    z_mesh = np.mgrid[-0.99:-0.01:100j, -0.99:-0.01:100j]
    z_mesh = np.vstack([z_mesh, np.zeros((2, 100, 100))])

    for i in range(z_mesh.shape[1]):
        for j in range(z_mesh.shape[2]):
            z = torch.randn(1, ndim_z)
            z[0, 0] = z_mesh[0, i, j]
            z[0, 1] = z_mesh[1, i, j]
            x_i = model(torch.cat([z, zero_pad, y_cond], dim=1).to(device),
                        rev=True)
            x_i = x_i.cpu().data.numpy()

            dist_1 = np.abs(mode_1[0][0] - x_i[0][0])
            dist_2 = np.abs(mode_2[0][0] - x_i[0][1])
            if dist_1 < dist_2:
                z_mesh[2, i, j] = dist_1
                z_mesh[3, i, j] = 0
            else:
                z_mesh[2, i, j] = dist_2
                z_mesh[3, i, j] = 1

    # Normalise the intensities of each tag to a maximum of 1.
    for tag in (0, 1):
        sel = z_mesh[3] == tag
        z_mesh[2][sel] = z_mesh[2][sel] / np.max(z_mesh[2][sel])

    bg_color = 'black'
    fg_color = 'red'

    fig = plt.figure(facecolor=bg_color, edgecolor=fg_color)
    ax = fig.add_subplot(111)
    ax.patch.set_facecolor(bg_color)
    ax.xaxis.set_tick_params(color=fg_color, labelcolor=fg_color)
    ax.yaxis.set_tick_params(color=fg_color, labelcolor=fg_color)
    for spine in ax.spines.values():
        spine.set_color(fg_color)
    for tag, cmap in ((0, 'Greens'), (1, 'Purples')):
        sel = z_mesh[3] == tag
        ax.scatter(z_mesh[0][sel],
                   z_mesh[1][sel],
                   s=1,
                   c=z_mesh[2][sel],
                   cmap=cmap)
    ax.set_xlabel('z-space', color=fg_color)
    ax.set_ylabel('z-space', color=fg_color)
    fig.savefig('%sstruct_z.png' % out_dir, dpi=360)
    plt.close(fig)


def main():
    """Train an invertible network on the toy signal model and plot results.

    Steps, all driven by the hard-coded settings at the top of the function:
    generate data with ``data.generate``, hold out the last ``r * r`` samples
    for plotting, build a three-layer reversible network, then train it with
    ``train`` for ``n_epochs`` epochs. Every ``plot_cadence`` epochs the
    inverse pass is sampled for each held-out example and compared with the
    precomputed likelihood / posterior samples; figures are written under
    ``out_dir``.

    Relies on names defined elsewhere in the module (``data``, ``train``,
    ``device``, the FrEIA node classes, ``MMD_multiscale``, ``fit``,
    ``make_contour_plot``, ``overlap_tests``, ``plot_losses``).
    """

    # Set up simulation parameters
    batch_size = 128  # set batch size
    r = 3  # the grid dimension for the output tests
    test_split = r * r  # number of testing samples to use
    sig_model = 'sg'  # the signal model to use
    sigma = 0.2  # the noise std
    ndata = 128  # number of data samples in time series
    bound = [0.0, 1.0, 0.0, 1.0]  # effective bound for likelihood
    seed = 1  # seed for generating data
    out_dir = "/home/hunter.gabbard/public_html/CBC/cINNamon/gausian_results/"
    n_neurons = 0
    do_contours = True  # if True, plot contours of predictions by INN
    plot_cadence = 50
    do_latent_struc = False  # if True, plot latent space 2D structure
    conv_nn = False  # if True, use convolutional nn structure

    # Create the output directory if needed (no shell involved).
    os.makedirs(out_dir, exist_ok=True)

    # generate data
    pos, labels, x, sig = data.generate(
        model=sig_model,
        tot_dataset_size=int(1e6),
        ndata=ndata,
        sigma=sigma,
        prior_bound=bound,
        seed=seed)

    if do_latent_struc:
        # calculate mode of x-space for both pars
        mode_1 = stats.mode(np.array(pos[:, 0]))
        mode_2 = stats.mode(np.array(pos[:, 1]))

    # separate the test data for plotting (the LAST test_split samples)
    pos_test = pos[-test_split:]
    labels_test = labels[-test_split:]
    sig_test = sig[-test_split:]

    # plot the test data examples
    fig_test, test_axes = plt.subplots(r, r, figsize=(6, 6))
    for cnt, ax in enumerate(test_axes.flat):
        ax.plot(x, np.array(labels_test[cnt, :]), '.')
        ax.plot(x, np.array(sig_test[cnt, :]), '-')
        ax.axis([0, 1, -1.5, 1.5])
    fig_test.savefig("%stest_distribution.png" % out_dir, dpi=360)
    plt.close(fig_test)

    # setting up the model
    ndim_x = 2  # number of posterior parameter dimensions (x,y)
    ndim_y = ndata  # number of label dimensions (noisy data samples)
    ndim_z = 200  # number of latent space dimensions
    ndim_tot = max(
        ndim_x,
        ndim_y + ndim_z) + n_neurons  # must be > ndim_x and > ndim_y + ndim_z

    # define different parts of the network
    # define input node
    inp = InputNode(ndim_tot, name='input')

    # define hidden layer nodes: three identical reversible layers whose
    # sub-network type depends on conv_nn
    filtsize = 3
    dropout = 0.0
    clamp = 1.0
    if conv_nn:
        f_class = F_conv
        f_args = {
            'kernel_size': filtsize,
            'leaky_slope': 0.1,
            'batch_norm': False
        }
    else:
        f_class = F_fully_connected
        f_args = {'dropout': dropout}

    hidden = []
    prev = inp
    for _ in range(3):
        prev = Node(
            [prev.out0], rev_multiplicative_layer, {
                'F_class': f_class,
                'clamp': clamp,
                'F_args': dict(f_args)  # each layer gets its own copy
            })
        hidden.append(prev)

    # define output layer node
    outp = OutputNode([hidden[-1].out0], name='output')

    nodes = [inp] + hidden + [outp]
    model = ReversibleGraphNet(nodes)

    # Train model
    # Training parameters
    n_epochs = 12000
    meta_epoch = 12  # StepLR step_size: scheduler steps between LR decays
    n_its_per_epoch = 12

    lr = 1e-2
    gamma = 0.01**(1. / 120)
    l2_reg = 2e-5

    y_noise_scale = 3e-2
    zeros_noise_scale = 3e-2

    # relative weighting of losses:
    lambd_predict = 4000.  # forward pass
    lambd_latent = 900.  # latent space
    lambd_rev = 1000.  # backwards pass

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(0.8, 0.8),
                                 eps=1e-04,
                                 weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    # Set up training set data loader. The test examples are the last
    # test_split samples, so training must use everything BEFORE them;
    # slicing with [test_split:] would have trained on the test examples.
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pos[:-test_split],
                                       labels[:-test_split]),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)

    # initialisation of network weights (only done for the convolutional net)
    if conv_nn:
        for mod_list in model.children():
            for block in mod_list.children():
                for coeff in block.children():
                    coeff.conv3.weight.data = 0.01 * torch.randn(
                        coeff.conv3.weight.shape)
    model.to(device)

    # number of test samples to use after training
    N_samp = 2500

    # precompute true likelihood on the test data
    Ngrid = 64
    lik = np.zeros((r, r, Ngrid * Ngrid))
    true_post = np.zeros((r, r, N_samp, 2))
    lossf_hist = []
    lossrev_hist = []
    losstot_hist = []
    losslatent_hist = []
    beta_score_hist = []

    for cnt in range(test_split):
        i, j = divmod(cnt, r)
        # mvec / cvec from the last call are reused for all contour plots below
        mvec, cvec, temp, post_points = data.get_lik(
            np.array(labels_test[cnt, :]).flatten(),
            n_grid=Ngrid,
            sig_model=sig_model,
            sigma=sigma,
            xvec=x,
            bound=bound)
        lik[i, j, :] = temp.flatten()
        true_post[i, j, :] = post_points[:N_samp]

    # start training loop
    t_start = time()
    try:
        # loop over number of epochs
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            # train the model
            losstot, losslatent, lossrev, lossf, lambd_latent = train(
                model, train_loader, n_its_per_epoch, zeros_noise_scale,
                batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale,
                optimizer, lambd_predict, loss_fit, lambd_latent, loss_latent,
                lambd_rev, loss_backward, conv_nn, i_epoch)

            # Step the schedule AFTER the optimizer updates of this epoch;
            # stepping first skips the first LR value on PyTorch >= 1.1.
            scheduler.step()

            # append current loss value to loss histories
            lossf_hist.append(lossf.data.item())
            lossrev_hist.append(lossrev.data.item())
            losstot_hist.append(losstot)
            losslatent_hist.append(losslatent.data.item())
            pe_losses = [
                losstot_hist, losslatent_hist, lossrev_hist, lossf_hist
            ]

            # only plot every plot_cadence epochs, and never at epoch 0
            if i_epoch == 0 or i_epoch % plot_cadence != 0:
                continue

            if do_latent_struc:
                _plot_latent_structure(model, labels_test[0, :], mode_1,
                                       mode_2, ndim_y, ndim_z, ndim_tot,
                                       out_dir, device)

            # loop over the test cases and plot results in a grid
            fig, axes = plt.subplots(r, r, figsize=(6, 6))
            for cnt in range(test_split):
                i, j = divmod(cnt, r)
                ax = axes[i, j]

                # network input for the inverse pass:
                # [latent draw | zero padding | repeated observation]
                y_samps = np.tile(np.array(labels_test[cnt, :]),
                                  N_samp).reshape(N_samp, ndim_y)
                y_samps = torch.tensor(y_samps, dtype=torch.float)
                y_samps = torch.cat([
                    torch.randn(N_samp, ndim_z),
                    torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),
                    y_samps
                ],
                                    dim=1)
                y_samps = y_samps.to(device)

                if conv_nn:
                    y_samps = y_samps.reshape(y_samps.shape[0],
                                              y_samps.shape[1], 1, 1)
                # pure inference: no autograd graph needed
                with torch.no_grad():
                    rev_x = model(y_samps, rev=True)
                rev_x = rev_x.cpu().data.numpy()

                if conv_nn:
                    rev_x = rev_x.reshape(rev_x.shape[0], rev_x.shape[1])

                # plot the samples and the true contours
                ax.contour(mvec,
                           cvec,
                           lik[i, j, :].reshape(Ngrid, Ngrid),
                           levels=[0.68, 0.9, 0.99])
                ax.scatter(rev_x[:, 0],
                           rev_x[:, 1],
                           s=0.5,
                           alpha=0.5,
                           color='red')
                # NOTE(review): true_post columns are plotted as (1, 0) while
                # rev_x is plotted as (0, 1) - kept as in the original.
                ax.scatter(true_post[i, j, :, 1],
                           true_post[i, j, :, 0],
                           s=0.5,
                           alpha=0.5,
                           color='blue')
                ax.plot(pos_test[cnt, 0],
                        pos_test[cnt, 1],
                        '+r',
                        markersize=8)
                ax.axis(bound)

                # add contours to results
                if do_contours:
                    try:
                        contour_x = rev_x[:, 0]
                        contour_y = rev_x[:, 1]
                        contour_dataset = np.array([contour_x, contour_y])
                        kernel_cnn = make_contour_plot(ax,
                                                       contour_x,
                                                       contour_y,
                                                       contour_dataset,
                                                       'red',
                                                       flip=False,
                                                       kernel_cnn=False)

                        # run overlap tests on results
                        contour_x = true_post[i, j][:, 1]
                        contour_y = true_post[i, j][:, 0]
                        contour_dataset = np.array([contour_x, contour_y])
                        ks_score, ad_score, beta_score = overlap_tests(
                            rev_x, true_post[i, j], pos_test[cnt],
                            kernel_cnn, gaussian_kde(contour_dataset))
                        ax.legend(
                            ['Overlap: %s' % str(np.round(beta_score, 3))])

                        beta_score_hist.append([beta_score])
                    except ValueError as err:
                        # Best effort: keep training, but say why this panel
                        # has no contours / overlap score.
                        print('Skipping contours for panel (%d, %d): %s' %
                              (i, j, err))

            # Save the results to file, then release the figure so figures
            # do not pile up over the course of training.
            fig.savefig('%sposteriors_%s.png' % (out_dir, i_epoch), dpi=360)
            fig.savefig('%slatest.png' % out_dir, dpi=360)
            plt.close(fig)

            plot_losses(pe_losses,
                        '%spe_losses.png' % out_dir,
                        legend=['PE-GEN'])
            plot_losses(pe_losses,
                        '%spe_losses_logscale.png' % out_dir,
                        logscale=True,
                        legend=['PE-GEN'])

    except KeyboardInterrupt:
        # allow Ctrl-C to stop training early and still report the timing
        pass
    finally:
        print(f"\n\nTraining took {(time() - t_start) / 60:.2f} minutes\n")
Exemple #3
0
def main():
    """Train an invertible network on GW templates and plot posterior samples.

    Loads templates and parameters with ``load_gw_data``, adds noise to the
    templates, reads the whitened event waveform from an HDF5 file and the
    reference posterior samples from a pickle file, builds a two-layer
    reversible network and trains it with ``train``. Every ``plot_cadence``
    epochs the inverse pass is evaluated on ``N_samp`` copies of the event and
    handed to ``plot_pe_samples``; loss curves are written with
    ``plot_losses``.

    Relies on module-level configuration not visible in this function:
    ``add_noise_real``, ``n_real``, ``n_sig``, ``event_name``, ``tag``,
    ``n_pix``, ``out_path``, ``n_neurons``, ``batch_size``, ``meta_epoch``,
    ``n_epochs``, ``n_its_per_epoch``, ``plot_cadence`` and ``device``.
    """
    # load in gw templates and signals
    signal_train_images, signal_train_pars, signal_image, noise_signal, signal_pars = load_gw_data()

    if add_noise_real:
        # NOTE(review): train_array / train_pe_array are built here but never
        # used afterwards - the tensors below are made from
        # signal_train_images / signal_train_pars. Presumably the noisy
        # realisations were meant to become the training set; confirm.
        train_array = []
        train_pe_array = []
        for i in range(len(signal_train_images)):
            for j in range(n_real):
                train_array.append([signal_train_images[i] + np.random.normal(loc=0.0, scale=n_sig) / 817.98 * 1079.23])
                train_pe_array.append([signal_train_pars[i]])
        train_array = np.array(train_array)
        train_pe_array = np.array(train_pe_array)
        train_array = train_array.reshape(train_array.shape[0], train_array.shape[2])
        train_pe_array = train_pe_array.reshape(train_pe_array.shape[0], train_pe_array.shape[2])
    else:
        # add one noise draw to each template in place
        for i in range(len(signal_train_images)):
            signal_train_images[i] += np.random.normal(loc=0.0, scale=n_sig) / 817.98 * 1079.23

    # Load in lalinference noise signal. The context manager closes the HDF5
    # handle; [:] copies the dataset into memory first.
    # TODO: the 1079.23 / 817.98 scale factors should not be hard-coded.
    with h5py.File("gw_data/data/%s0%s.hdf5" % (event_name, tag), "r") as h5_file:
        noise_signal = np.reshape(h5_file['wht_wvf'][:] * 1079.23, (n_pix, 1))

    plt.plot(noise_signal)
    plt.savefig('%s/test.png' % out_path)
    plt.close()

    # load in lalinference samples
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # acceptable because this is a local, trusted data product.
    with open('gw_data/data/gw150914_mc_q_lalinf_post_srate-1024_python3.sav', 'rb') as f:
        lalinf_post = pickle.load(f)
    kernel_lalinf = gaussian_kde(lalinf_post)

    # declare gw variants of positions and labels
    labels = torch.tensor(signal_train_images, dtype=torch.float)
    pos = torch.tensor(signal_train_pars, dtype=torch.float)

    # setting up the model
    ndim_x = 2    # number of parameter dimensions
    ndim_y = n_pix    # number of data dimensions
    ndim_z = 100    # number of latent space dimensions
    ndim_tot = n_pix + ndim_z + ndim_x + n_neurons

    # define different parts of the network
    # define input node
    inp = InputNode(ndim_tot, name='input')

    # Define hidden layer nodes. Both sub-network options live in ONE 'F_args'
    # dict: a repeated 'F_args' key would silently drop the first entry.
    t1 = Node([inp.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.0, 'batch_norm': False}})

    t2 = Node([t1.out0], rev_multiplicative_layer,
              {'F_class': F_fully_connected, 'clamp': 2.0,
               'F_args': {'dropout': 0.0, 'batch_norm': False}})

    # define output layer node
    outp = OutputNode([t2.out0], name='output')

    nodes = [inp, t1, t2, outp]
    model = ReversibleGraphNet(nodes)

    # Train model

    lr = 1e-4
    gamma = 0.01**(1./120)
    l2_reg = 2e-5

    y_noise_scale = 1     # scale applied to the event data; also passed to train
    zeros_noise_scale = 3e-2  # passed through to train

    # relative weighting of losses:
    lambd_predict = 300.  # forward pass
    lambd_latent = 300.   # latent space
    lambd_rev = 400.      # backwards pass

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8),
                                 eps=1e-04, weight_decay=l2_reg, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    # set up training set data loader
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pos, labels),
        batch_size=batch_size, shuffle=True, drop_last=True)

    # re-initialise the fc3 weights of every sub-network with small values
    for mod_list in model.children():
        for block in mod_list.children():
            for coeff in block.children():
                coeff.fc3.weight.data = 0.01 * torch.randn(coeff.fc3.weight.shape)

    model.to(device)

    # number of test samples to use after training
    N_samp = 4000

    # N_samp identical copies of the event, one per row: (N_samp, n_pix)
    y_samps = y_noise_scale * torch.tensor(
        np.repeat(noise_signal.T, N_samp, axis=0), dtype=torch.float)

    # Make test samples. First element is the latent space draw, second is
    # the zero padding needed to reach ndim_tot, third is the time series.
    y_samps = torch.cat([torch.randn(N_samp, ndim_z),
                         torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z),
                         y_samps], dim=1)
    y_samps = y_samps.to(device)

    # start training loop
    lossf_hist = []
    lossrev_hist = []
    beta_score_hist = []
    kernel_cnn = False
    t_start = time()
    try:
        # loop over number of epochs
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            _, lossf, lossrev = train(model, train_loader, n_its_per_epoch, zeros_noise_scale, batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale, optimizer, lambd_predict, loss_fit, lambd_latent, loss_latent, lambd_rev, loss_backward, i_epoch)

            # Step the schedule AFTER the optimizer updates of this epoch;
            # stepping first skips the first LR value on PyTorch >= 1.1.
            scheduler.step()

            # append current loss value to loss histories
            lossf_hist.append(lossf.item())
            lossrev_hist.append(lossrev.item())
            pe_losses = [lossf_hist, lossrev_hist]

            # only plot every plot_cadence epochs, and never at epoch 0
            if i_epoch == 0 or i_epoch % plot_cadence != 0:
                continue

            # Predict parameters of signal. The result is only used for
            # plotting, so it is computed only on plotting epochs and
            # without building an autograd graph.
            with torch.no_grad():
                rev_x = model(y_samps, rev=True)
            rev_x = rev_x.cpu().data.numpy()

            pe_std = [0.02185649964844209, 0.005701401364171313]  # this will need to be removed
            beta_score_hist.append([plot_pe_samples(rev_x, signal_pars, out_path, i_epoch, lalinf_post, pe_std, kernel_lalinf=kernel_lalinf, kernel_cnn=kernel_cnn)])
            plt.plot(np.linspace(plot_cadence, i_epoch, len(beta_score_hist)), beta_score_hist)
            plt.savefig('%s/latest/beta_hist.png' % out_path)
            plt.close()

            # plot loss curves - non-log and log
            plot_losses(pe_losses, '%s/latest/pe_losses.png' % out_path, legend=['PE-GEN'])
            plot_losses(pe_losses, '%s/latest/pe_losses_logscale.png' % out_path, logscale=True, legend=['PE-GEN'])

    except KeyboardInterrupt:
        # allow Ctrl-C to stop training early and still report the timing
        pass
    finally:
        print(f"\n\nTraining took {(time() - t_start) / 60:.2f} minutes\n")
Exemple #4
0
def main(out_path='/data/public_html/chrism/FrEIA'):
    """Train a reversible network on generated data and plot test posteriors.

    The routine generates a dataset with ``data.generate``, holds out the
    last ``r * r`` samples for plotting, builds a ``ReversibleGraphNet`` made
    of three ``rev_multiplicative_layer`` nodes and trains it for
    ``n_epochs`` epochs with the module-level ``train`` function.  After
    every epoch the network is evaluated in reverse on each held-out sample
    and the resulting points are scattered on top of contours of the grid
    returned by ``data.get_lik``.

    Relies on names defined elsewhere in the file (``data``, ``train``,
    ``device``, ``fit``, ``MMD_multiscale``, the FrEIA node classes, ...).

    Args:
        out_path: Directory that receives ``test_distribution.png``,
            ``posteriors_<epoch>.png`` and ``latest.png``.  Defaults to the
            location that used to be hard-coded.

    A ``KeyboardInterrupt`` raised during training is swallowed so that the
    elapsed training time is still printed.
    """
    # Set up simulation parameters
    batch_size = 1600  # set batch size
    r = 3  # the grid dimension for the output tests
    test_split = r * r  # number of testing samples to use
    sig_model = 'sg'  # the signal model to use
    sigma = 0.2  # the noise std
    ndata = 32  # number of data samples
    bound = [0.0, 1.0, 0.0, 1.0]  # effective bound for likelihood
    seed = 1  # seed for generating data

    # generate data
    pos, labels, x, sig = data.generate(model=sig_model,
                                        tot_dataset_size=2**20,
                                        ndata=ndata,
                                        sigma=sigma,
                                        prior_bound=bound,
                                        seed=seed)

    # separate the test data for plotting (the last `test_split` samples)
    pos_test = pos[-test_split:]
    labels_test = labels[-test_split:]
    sig_test = sig[-test_split:]

    # plot the test data examples; squeeze=False keeps `axes` two-dimensional
    # even when r == 1
    fig, axes = plt.subplots(r, r, figsize=(6, 6), squeeze=False)
    for cnt, ax in enumerate(axes.flat):
        ax.plot(x, np.array(labels_test[cnt, :]), '.')
        ax.plot(x, np.array(sig_test[cnt, :]), '-')
        ax.axis([0, 1, -1.5, 1.5])
    plt.savefig('%s/test_distribution.png' % out_path, dpi=360)
    plt.close()

    # setting up the model
    ndim_x = 2  # number of posterior parameter dimensions (x,y)
    ndim_y = ndata  # number of label dimensions (noisy data samples)
    ndim_z = 8  # number of latent space dimensions
    # total width must be >= ndim_x and >= ndim_y + ndim_z
    ndim_tot = max(ndim_x, ndim_y + ndim_z)

    # define the network: input node, identical hidden nodes, output node
    n_hidden = 3
    nodes = [InputNode(ndim_tot, name='input')]
    for _ in range(n_hidden):
        # a fresh argument dict per node so that no state is shared
        nodes.append(
            Node([nodes[-1].out0], rev_multiplicative_layer, {
                'F_class': F_fully_connected,
                'clamp': 2.0,
                'F_args': {
                    'dropout': 0.2
                }
            }))
    nodes.append(OutputNode([nodes[-1].out0], name='output'))
    model = ReversibleGraphNet(nodes)

    # Training parameters
    n_epochs = 1000
    meta_epoch = 12  # number of scheduler steps between learning-rate decays
    n_its_per_epoch = 12

    lr = 1e-2
    gamma = 0.01**(1. / 120)
    l2_reg = 2e-5

    y_noise_scale = 3e-2
    zeros_noise_scale = 3e-2

    # relative weighting of losses:
    lambd_predict = 300.  # forward pass
    lambd_latent = 300.  # latent space
    lambd_rev = 400.  # backwards pass

    # define optimizer and learning-rate schedule
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(0.8, 0.8),
                                 eps=1e-04,
                                 weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=meta_epoch,
                                                gamma=gamma)

    # define the three loss functions
    loss_backward = MMD_multiscale
    loss_latent = MMD_multiscale
    loss_fit = fit

    # set up training set data loader; everything except the held-out tail
    # is used, so the plotted test samples are never trained on
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(pos[:-test_split],
                                       labels[:-test_split]),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)

    # initialisation of network weights: small random values for every fc3
    for mod_list in model.children():
        for block in mod_list.children():
            for coeff in block.children():
                coeff.fc3.weight.data = 0.01 * torch.randn(
                    coeff.fc3.weight.shape)
    model.to(device)

    # initialize plot for showing testing results
    fig, axes = plt.subplots(r, r, figsize=(6, 6), squeeze=False)

    # number of samples drawn per test case after each epoch
    n_samp = 256

    # precompute the likelihood grid for every test case
    n_grid = 64
    lik = np.zeros((r, r, n_grid * n_grid))
    for cnt, (i, j) in enumerate(np.ndindex(r, r)):
        mvec, cvec, temp = data.get_lik(np.array(
            labels_test[cnt, :]).flatten(),
                                        n_grid=n_grid,
                                        sig_model=sig_model,
                                        sigma=sigma,
                                        xvec=x,
                                        bound=bound)
        lik[i, j, :] = temp.flatten()

    # start training loop; the timer starts outside `try` so that `finally`
    # can always report it
    t_start = time()
    try:
        for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

            # NOTE(review): the scheduler is stepped before this epoch's
            # optimizer updates; recent PyTorch releases expect it after
            # them - confirm against the PyTorch version in use.
            scheduler.step()

            # train the model
            train(model, train_loader, n_its_per_epoch, zeros_noise_scale,
                  batch_size, ndim_tot, ndim_x, ndim_y, ndim_z, y_noise_scale,
                  optimizer, lambd_predict, loss_fit, lambd_latent,
                  loss_latent, lambd_rev, loss_backward, i_epoch)

            # loop over the test cases and plot results in a grid; no
            # gradients are needed for plotting
            # NOTE(review): the module mode is not changed here, so dropout
            # behaves however `train` left it - confirm this is intended.
            with torch.no_grad():
                for cnt, (i, j) in enumerate(np.ndindex(r, r)):
                    ax = axes[i, j]

                    # repeat the test label n_samp times and prepend latent
                    # noise plus zero padding up to ndim_tot columns
                    y_samps = np.tile(np.array(labels_test[cnt, :]),
                                      n_samp).reshape(n_samp, ndim_y)
                    y_samps = torch.tensor(y_samps, dtype=torch.float)
                    y_samps = torch.cat(
                        [
                            torch.randn(n_samp, ndim_z),
                            torch.zeros(n_samp, ndim_tot - ndim_y - ndim_z),
                            y_samps
                        ],
                        dim=1)
                    y_samps = y_samps.to(device)

                    # use the network to predict parameters
                    rev_x = model(y_samps, rev=True)
                    rev_x = rev_x.cpu().data.numpy()

                    # plot the samples and the likelihood contours
                    ax.clear()
                    ax.contour(mvec,
                               cvec,
                               lik[i, j, :].reshape(n_grid, n_grid),
                               levels=[0.68, 0.9, 0.99])
                    ax.scatter(rev_x[:, 0], rev_x[:, 1], s=0.5, alpha=0.5)
                    ax.plot(pos_test[cnt, 0],
                            pos_test[cnt, 1],
                            '+r',
                            markersize=8)
                    ax.axis(bound)

            # save the results to file; saving through `fig` guarantees the
            # posterior grid is written even if another figure became current
            fig.canvas.draw()
            fig.savefig('%s/posteriors_%s.png' % (out_path, i_epoch), dpi=360)
            fig.savefig('%s/latest.png' % out_path, dpi=360)

    except KeyboardInterrupt:
        # allow the user to stop training early without a traceback
        pass
    finally:
        print("\n\nTraining took %.2f minutes\n" % ((time() - t_start) / 60))