Ejemplo n.º 1
0
# Tail of a reconstruction script: `minimum`, `data_noised`, `ipklin`,
# `stages`, `ic`, `fin`, `fpath`, `bs` and `dtype` are defined outside this
# excerpt.
print(minimum)
print('\nminimized\n')

# Drop whatever graph was built before constructing the evaluation graph.
tf.reset_default_graph()

# The generated linear field is multiplied by 0, so only `minimum` (reshaped
# to the shape of `data_noised`) contributes to `tfic`.
tfic = linear_field(
    FLAGS.nc, FLAGS.box_size, ipklin, batch_size=1, seed=100,
    dtype=dtype) * 0 + minimum.reshape(data_noised.shape)
# Forward model on `tfic`: lpt_init -> nbody over `stages` -> cic_paint of
# final_state[0] onto a zero mesh shaped like `tfic`.
state = lpt_init(tfic, a0=0.1, order=1)
final_state = nbody(state, stages, FLAGS.nc)
tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
with tf.Session() as sess:
    minic, minfin = sess.run([tfic, tfinal_field])

# Hand the evaluated pair and the reference pair (ic, fin) to the `dg`
# figure helpers (presumably image and two-point diagnostics -- confirm in dg).
dg.saveimfig(0, [minic, minfin], [ic, fin], fpath + '')
dg.save2ptfig(0, [minic, minfin], [ic, fin], fpath + '', bs)

##
##
##def main(_):
##
##    dtype=tf.float32
##
##    startw = time.time()
##
##    tf.random.set_random_seed(100)
##    np.random.seed(100)
##
##
##    # Compute a few things first, using simple tensorflow
##    a0=FLAGS.a0
Ejemplo n.º 2
0
def main():
    """Reconstruct the linear field from a noisy final field.

    Loads ``ic.npy`` / ``final.npy`` from ``fpath``; if loading fails the
    pair is regenerated with ``linear_field`` and ``pm`` and saved.  Unit
    Gaussian noise is added to the final field and ``chisq + prior`` is
    minimised for every smoothing value in the module-level ``RRs``, using
    SciPy L-BFGS-B or Adam according to the module-level ``optimizer``.
    After each stage the figures and the ``ic-<i>`` / ``final-<i>`` arrays
    are written under ``fpath``.  The process exits with status 0 at the end.

    Raises:
        ValueError: if ``optimizer`` is neither ``'lbfgs'`` nor ``'adam'``.
    """
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Best effort: any load failure falls back to regenerating the data.
        print('Exception occured', e)
        tfic = linear_field(nc,
                            bs,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        tfinal_field = pm(tfic)
        ic, fin = tfic.numpy(), tfinal_field.numpy()
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    print('\ndata constructed\n')

    noise = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)
    data = fin + noise

    @tf.function
    def recon_prototype(linear, Rsm):
        """Return ``chisq + prior`` for ``linear`` at smoothing value ``Rsm``."""
        linear = tf.reshape(linear, data.shape)
        final_field = pm(linear)

        base = final_field - data.astype(np.float32)

        if anneal:
            print("\nAdd annealing section to graph\n")
            # Weight the residual in Fourier space by exp(-kmesh**2 * Rsmsq).
            Rsmsq = tf.multiply(Rsm * bs / nc, Rsm * bs / nc)
            smwts = tf.exp(tf.multiply(-kmesh**2, Rsmsq))
            basek = r2c3d(base, norm=nc**3)
            basek = tf.multiply(basek, tf.cast(smwts, tf.complex64))
            base = c2r3d(basek, norm=nc**3)

        chisq = tf.multiply(base, base)
        chisq = tf.reduce_sum(chisq)
        chisq = tf.multiply(chisq, 1 / nc**3, name='chisq')

        # Prior: |FFT(linear)|^2 weighted by 1 / priorwt.
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1 / priorwt))
        prior = tf.multiply(prior, 1 / nc**3, name='prior')

        return chisq + prior

    @tf.function
    def val_and_grad(x, Rsm):
        """Return the loss and its gradient with respect to ``x``."""
        print("val and grad : ", x.shape)
        with tf.GradientTape() as tape:
            tape.watch(x)
            loss = recon_prototype(x, Rsm)
        return loss, tape.gradient(loss, x)

    @tf.function
    def grad(x, Rsm):
        """Return only the gradient of the loss with respect to ``x``."""
        with tf.GradientTape() as tape:
            tape.watch(x)
            loss = recon_prototype(x, Rsm)
        return tape.gradient(loss, x)

    # Objective for L-BFGS-B: SciPy works in float64 on a 1-D vector, the
    # model in float32 on a mesh (recon_prototype reshapes internally).
    def func(x, RR):
        loss, gradient = val_and_grad(x=tf.constant(x, dtype=tf.float32),
                                      Rsm=tf.constant(RR, dtype=tf.float32))
        return (loss.numpy().astype(np.float64),
                gradient.numpy().astype(np.float64).ravel())

    # Create an optimizer for Adam.
    opt = tf.keras.optimizers.Adam(learning_rate=lr)

    ##Reconstruction
    x0 = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)
    linear = tf.Variable(name='linmesh',
                         shape=(1, nc, nc, nc),
                         dtype=tf.float32,
                         initial_value=x0,
                         trainable=True)

    for iR, RR in enumerate(RRs):

        if optimizer == 'lbfgs':
            # scipy.optimize.minimize documents x0 as shape (n,), so the
            # mesh-shaped guess is flattened before it is handed over.
            results = sopt.minimize(fun=func,
                                    x0=np.ravel(x0),
                                    args=(RR,),
                                    jac=True,
                                    method='L-BFGS-B',
                                    tol=1e-10,
                                    options={
                                        'maxiter': niter,
                                        'ftol': 1e-12,
                                        'gtol': 1e-12,
                                        'eps': 1e-12
                                    })
            print(results)
            minic = results.x.reshape(data.shape)

        elif optimizer == 'adam':
            for i in range(niter):
                grads = grad([linear], tf.constant(RR, dtype=tf.float32))
                opt.apply_gradients(zip(grads, [linear]))
            minic = linear.numpy().reshape(data.shape)

        else:
            # Without this, an unknown optimizer surfaces as a NameError on
            # `minic` below.
            raise ValueError("unknown optimizer: %r" % (optimizer,))

        print('\nminimized\n')
        minfin = pm(tf.constant(minic, dtype=tf.float32)).numpy()
        dg.saveimfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '')
        dg.save2ptfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '', bs)
        # Warm-start the next (smaller) smoothing stage from this minimum.
        x0 = minic
        np.save(fpath + 'ic-%d' % iR, minic)
        np.save(fpath + 'final-%d' % iR, minfin)

    # Same effect as exit(0) without relying on the site-provided builtin.
    raise SystemExit(0)
Ejemplo n.º 3
0
def main(_):
    """Fit a bias model to a halo-mass mesh and run the estimator reconstruction.

    Steps, in order:
      1. Read the linear power table, the simulation meshes and the FOF halo
         catalogue, paint the halos and build the overdensity ``data``.
      2. Evolve the true initial field with lpt_init/nbody, fit ``getbias``
         and measure the power spectrum of the model error.
      3. Run ``recon_estimator.predict`` from the true field and from a
         random field, then train for every smoothing value in ``RRs`` and
         predict again, saving figures and meshes under ``fpath`` each time.

    The process exits with status 0 at the end.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    infield = True
    dtype = tf.float32
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    nc, bs = FLAGS.nc, FLAGS.box_size
    a0, a, nsteps = FLAGS.a0, FLAGS.af, FLAGS.nsteps
    stages = np.linspace(a0, a, nsteps, endpoint=True)
    numd = 1e-3

    ##Begin here
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)

    final = tools.readbigfile('../data//L0400_N0128_S0100_05step/mesh/d/')
    ic = tools.readbigfile('../data/L0400_N0128_S0100_05step/mesh/s/')

    # Entries 1 .. int(bs**3 * numd) - 1 of the halo catalogue are used.
    nhalo = int(bs**3 * numd)
    hpos = tools.readbigfile(
        '../data/L0400_N0512_S0100_40step/FOF/PeakPosition//')[1:nhalo]
    hmass = tools.readbigfile(
        '../data/L0400_N0512_S0100_40step/FOF/Mass//')[1:nhalo].flatten()

    meshpos = tools.paintcic(hpos, bs, nc)
    meshmass = tools.paintcic(hpos, bs, nc, hmass.flatten() * 1e10)
    # `data` aliases `meshmass`; the in-place ops turn it into mesh/mean - 1.
    data = meshmass
    data /= data.mean()
    data -= 1
    kv = tools.fftk([nc, nc, nc], bs, symmetric=True, dtype=np.float32)
    datasm = tools.fingauss(data, kv, 3, np.pi * nc / bs)
    # Add a leading axis of length 1 to every mesh.
    ic = np.expand_dims(ic, 0)
    data = np.expand_dims(data, 0).astype(np.float32)
    datasm = np.expand_dims(datasm, 0).astype(np.float32)
    print("Min in data : %0.4e" % datasm.min())

    np.save(fpath + 'ic', ic)
    np.save(fpath + 'data', data)

    ####################################################
    # Evolve the true initial field to get the state used for the bias fit.
    tf.reset_default_graph()
    tfic = tf.constant(ic.astype(np.float32))
    state = lpt_init(tfic, a0=0.1, order=1)
    final_state = nbody(state, stages, FLAGS.nc)
    tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
    with tf.Session() as sess:
        state = sess.run(final_state)

    # state[0, 0] is rescaled by bs / nc before being passed to getbias.
    fpos = state[0, 0] * bs / nc
    bparams, bmodel = getbias(bs, nc, data[0] + 1, ic[0], fpos)
    errormesh = data - np.expand_dims(bmodel, 0)
    kerror, perror = tools.power(errormesh[0] + 1, boxsize=bs)
    kerror, perror = kerror[1:], perror[1:]
    print("Error power spectra", kerror, perror)
    print("\nkerror", kerror.min(), kerror.max(), "\n")
    print("\nperror", perror.min(), perror.max(), "\n")
    suff = "-error"
    dg.saveimfig(suff, [ic, errormesh], [ic, data], fpath + '/figs/')
    dg.save2ptfig(suff, [ic, errormesh], [ic, data], fpath + '/figs/', bs)
    ipkerror = iuspline(kerror, perror)

    ####################################################

    recon_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                             model_dir=fpath)

    def predict_input_fn(data=data,
                         M0=0.,
                         w=3.,
                         R0=0.,
                         off=None,
                         istd=None,
                         x0=None):
        """Build the (features, labels) pair used for prediction."""
        features = {}
        features['datasm'] = data
        features['R0'] = R0
        features['x0'] = x0
        features['bparams'] = bparams
        features['ipkerror'] = [kerror, perror]  #ipkerror
        return features, None

    def pick_prediction(results):
        """Return the prediction the stage output is built from.

        NOTE(review): the loop breaks only after `pred` has been rebound at
        i == 1, so this is the second yielded prediction whenever more than
        one is produced -- kept exactly as the inline loops behaved.
        """
        for i, pred in enumerate(results):
            if i > 0:
                break
        return pred

    def save_outputs(pred, suff, tag=''):
        """Write figures and the ic/fin/model meshes for one prediction."""
        dg.saveimfig(suff, [pred['ic'], pred['model']], [ic, data],
                     fpath + '/figs/')
        dg.save2ptfig(suff, [pred['ic'], pred['model']], [ic, data],
                      fpath + '/figs/', bs)
        np.save(fpath + '/reconmeshes/ic' + tag + suff, pred['ic'])
        np.save(fpath + '/reconmeshes/fin' + tag + suff, pred['final'])
        np.save(fpath + '/reconmeshes/model' + tag + suff, pred['model'])

    # Prediction starting from the true initial field.
    eval_results = recon_estimator.predict(
        input_fn=lambda: predict_input_fn(x0=ic), yield_single_examples=False)
    save_outputs(pick_prediction(eval_results), '-model', '_true')

    # Prediction starting from a random field.
    randominit = np.random.normal(size=data.size).reshape(data.shape)
    eval_results = recon_estimator.predict(
        input_fn=lambda: predict_input_fn(x0=randominit),
        yield_single_examples=False)
    save_outputs(pick_prediction(eval_results), '-init', '_init')

    # Train and evaluate model.
    RRs = [4., 2., 1., 0.5, 0.]
    niter = 100
    iiter = 0

    for R0 in RRs:

        print('\nFor iteration %d\n' % iiter)
        print('With  R0=%0.2f \n' % (R0))

        # Redefined on every pass so that it sees the current R0; it is
        # consumed by train() within the same iteration.
        def train_input_fn():
            features = {}
            features['datasm'] = data
            features['R0'] = R0
            features['bparams'] = bparams
            features['ipkerror'] = [kerror, perror]  #ipkerror
            features['x0'] = randominit
            features['lr'] = 0.01
            return features, None

        recon_estimator.train(input_fn=train_input_fn, max_steps=iiter + niter)
        # Prediction uses predict_input_fn's defaults (R0=0., x0=None).
        eval_results = recon_estimator.predict(input_fn=predict_input_fn,
                                               yield_single_examples=False)
        pred = pick_prediction(eval_results)

        iiter += niter
        save_outputs(pred, '-%d-R%d' % (iiter, R0))

    sys.exit(0)
Ejemplo n.º 4
0
def main(_):
    """Reconstruct the initial field by annealed iterative optimisation.

    Loads ``ic.npy`` / ``final.npy`` from ``fpath`` (regenerating and saving
    them when loading fails), builds the graph with ``recon_prototype(fin)``
    and, for each smoothing value in ``RRs``, runs ``niter`` update steps
    with periodic diagnostic figures.  The final fields are saved as
    ``ic_recon`` / ``final_recon`` and the process exits with status 0.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    dtype = tf.float32

    startw = time.time()

    tf.random.set_random_seed(100)
    np.random.seed(100)

    # Compute a few things first, using simple tensorflow
    a0 = FLAGS.a0
    a = FLAGS.af
    nsteps = FLAGS.nsteps
    bs, nc = FLAGS.box_size, FLAGS.nc
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Best effort: any load failure falls back to regenerating the data.
        print('Exception occured', e)
        tfic = linear_field(FLAGS.nc,
                            FLAGS.box_size,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        if FLAGS.nbody:
            state = lpt_init(tfic, a0=0.1, order=1)
            final_state = nbody(state, stages, FLAGS.nc)
        else:
            final_state = lpt_init(tfic, a0=stages[-1], order=1)
        tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
        with tf.Session() as sess:
            ic, fin = sess.run([tfic, tfinal_field])
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    tf.reset_default_graph()
    print('ic constructed')

    linear, final_field, update_ops, loss, chisq, prior, Rsm = recon_prototype(
        fin)

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        ic0, fin0 = sess.run([linear, final_field])
        dg.saveimfig('-init', [ic0, fin0], [ic, fin], fpath)
        start = time.time()

        titer = 20   # interval (in iterations) between diagnostics
        niter = 201  # iterations per smoothing value
        iiter = 0    # running total over all smoothing values

        start0 = time.time()
        RRs = [4, 2, 1, 0.5, 0]
        for iR, RR in enumerate(RRs):
            figdir = fpath + '/figs-R%02d' % (10 * RR)
            try:
                os.makedirs(figdir)
            except Exception as e:
                # Best effort: report the failure and carry on.
                print(e)

            for i in range(niter):
                iiter += 1
                sess.run(update_ops, {Rsm: RR})
                print(sess.run([loss, chisq, prior], {Rsm: RR}))
                if (i % titer == 0):
                    end = time.time()
                    print('Iter : ', i)
                    print('Time taken for %d iterations: ' % titer,
                          end - start)
                    start = end

                    ic1, fin1 = sess.run([linear, final_field])

                    dg.saveimfig(i, [ic1, fin1], [ic, fin], figdir)
                    dg.save2ptfig(i, [ic1, fin1], [ic, fin], figdir, bs)
            # Summary figures for this stage use the most recent snapshot.
            dg.saveimfig(i * (iR + 1), [ic1, fin1], [ic, fin], fpath + '/figs')
            dg.save2ptfig(i * (iR + 1), [ic1, fin1], [ic, fin],
                          fpath + '/figs', bs)

        ic1, fin1 = sess.run([linear, final_field])
        print('Total time taken for %d iterations is : ' % iiter,
              time.time() - start0)

    dg.saveimfig(i, [ic1, fin1], [ic, fin], fpath)
    dg.save2ptfig(i, [ic1, fin1], [ic, fin], fpath, bs)

    np.save(fpath + 'ic_recon', ic1)
    np.save(fpath + 'final_recon', fin1)
    # Wallclock is measured from function entry, not from the loop start.
    print('Total wallclock time is : ', time.time() - startw)

    # Same effect as exit(0) without relying on the site-provided builtin.
    raise SystemExit(0)
Ejemplo n.º 5
0
def main(_):
    """Reconstruct the initial field from a Poisson-sampled data mesh.

    Loads ``ic``, ``final`` and the ``psample`` data for the configured box
    and mesh size, builds a graph whose loss is the negative Poisson
    log-probability of the data plus a prior term, and runs Adam updates for
    each smoothing value in ``RRs`` with periodic diagnostic figures.  The
    final field and sample are saved as ``ic_recon`` / ``final_recon`` and
    the process exits with status 0.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    startw = time.time()

    tf.random.set_random_seed(100)
    np.random.seed(100)

    # Compute a few things first, using simple tensorflow
    a0 = FLAGS.a0
    a = FLAGS.af
    nsteps = FLAGS.nsteps
    bs, nc = FLAGS.box_size, FLAGS.nc
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    ic = np.load('../data/poisson_L%04d_N%03d/ic.npy'%(bs, nc))
    fin = np.load('../data/poisson_L%04d_N%03d/final.npy'%(bs, nc))
    data = np.load('../data/poisson_L%04d_N%03d/psample_%0.2f.npy'%(bs, nc, plambda))

    ################################################################
    tf.reset_default_graph()
    print('ic constructed')

    # Flattened random start; it is not passed to recon_prototype below.
    startpos = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)*1
    startpos = startpos.flatten()

    Rsm = tf.placeholder(tf.float32, name='smoothing')

    def recon_prototype(x0=None):
        """Build the reconstruction graph.

        Args:
            x0: optional initial value for the ``linmesh`` variable; a
                random-normal initializer is used when it is None.

        Returns:
            (linear, sample, update_ops, loss, logprob, prior)
        """
        if x0 is None:
            initializer = tf.random_normal_initializer()
        else:
            initializer = tf.constant_initializer(x0)
        linear = tf.get_variable('linmesh', shape=(1, nc, nc, nc), dtype=tf.float32,
                                 initializer=initializer, trainable=True)

        state = lpt_init(linear, a0=0.1, order=1)
        final_state = nbody(state, stages, FLAGS.nc)
        final_field = cic_paint(tf.zeros_like(linear), final_state[0])
        base = final_field

        if FLAGS.anneal:
            print('\nAdd annealing graph\n')
            # Weight the field in Fourier space by exp(-kmesh**2 * Rsmsq).
            Rsmsq = tf.multiply(Rsm*bs/nc, Rsm*bs/nc)
            smwts = tf.exp(tf.multiply(-kmesh**2, Rsmsq))
            basek = r2c3d(base, norm=nc**3)
            basek = tf.multiply(basek, tf.cast(smwts, tf.complex64))
            base = c2r3d(basek, norm=nc**3)

        galmean = tfp.distributions.Poisson(rate=plambda * (1 + base))
        sample = galmean.sample()
        logprob = -tf.reduce_sum(galmean.log_prob(data))

        # Prior: |FFT(linear)|^2 weighted by 1 / priorwt.
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1/priorwt))

        loss = logprob + prior

        opt = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)

        # Compute the gradients for a list of variables.
        grads_and_vars = opt.compute_gradients(loss, [linear])
        print("\ngradients : ", grads_and_vars)
        update_ops = opt.apply_gradients(grads_and_vars)

        return linear, sample, update_ops, loss, logprob, prior

    linear, sample, update_ops, loss, chisq, prior = recon_prototype()

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        ic0, sampl0 = sess.run([linear, sample], {Rsm:0.})
        dg.saveimfig('-init', [ic0, sampl0], [ic, data], fpath)
        start = time.time()

        titer = 10                 # interval between loss reports
        niter = FLAGS.niter + 1    # iterations per smoothing value
        iiter = 0                  # running total over all smoothing values

        start0 = time.time()
        RRs = [2, 1, 0.5, 0]
        for RR in RRs:
            figdir = fpath + '/figs-R%02d'%(10*RR)
            try:
                os.makedirs(figdir)
            except Exception as e:
                # Best effort: report the failure and carry on.
                print (e)

            for i in range(niter):
                iiter +=1
                sess.run(update_ops, {Rsm:RR})
                if (i%titer == 0):
                    end = time.time()
                    print('Iter : ', i)
                    print('Time taken for %d iterations: '%titer, end-start)
                    start = end
                    print(sess.run([loss, chisq, prior], {Rsm:RR}))
                if (i%(2*titer) == 0):
                    ic1, samp1 = sess.run([linear, sample], {Rsm:RR})
                    dg.saveimfig(i, [ic1, samp1], [ic, data], figdir)
                    dg.save2ptfig(i, [ic1, samp1], [ic, data], figdir, bs)

        ic1, samp1 = sess.run([linear, sample], {Rsm:0.})
        print('Total time taken for %d iterations is : '%iiter, time.time()-start0)

    dg.saveimfig(i, [ic1, samp1], [ic, data], fpath)
    dg.save2ptfig(i, [ic1, samp1], [ic, data], fpath, bs)

    np.save(fpath + 'ic_recon', ic1)
    np.save(fpath + 'final_recon', samp1)
    # Wallclock is measured from function entry, not from the loop start.
    print('Total wallclock time is : ', time.time()-startw)

    # Same effect as exit(0) without relying on the site-provided builtin.
    raise SystemExit(0)
Ejemplo n.º 6
0
def main(_):
    """Train the reconstruction estimator on Poisson-sampled data.

    Loads ``ic.npy`` and the ``psample`` file for the module-level ``nc`` and
    ``FLAGS.plambda``, then for every smoothing value in ``RRs`` trains the
    estimator for ``niter`` more steps, predicts, and writes diagnostic
    figures under ``fpath + '/figs/'``.  Exits with status 0 when done.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    print(mesh_shape)
    ##Begin here
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)

    tf.reset_default_graph()
    plambda = FLAGS.plambda
    ic = np.load('../data/poisson_N%03d/ic.npy' % nc)
    fin = np.load('../data/poisson_N%03d/psample_%0.2f.npy' % (nc, plambda))
    print('Data loaded')

    ########################################################
    print(ic.shape, fin.shape)
    recon_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                             model_dir=fpath)

    def eval_input_fn():
        """Features for prediction: R0 = 0, lr = 0 and no starting field."""
        features = {}
        features['data'] = fin
        features['R0'] = 0
        features['x0'] = None
        features['lr'] = 0
        return features, None

    # Train and evaluate model.

    RRs = [4., 2., 1., 0.5, 0.]
    niter = 200
    iiter = 0

    for R0 in RRs:
        print('\nFor iteration %d and R=%0.1f\n' % (iiter, R0))

        # Redefined on every pass so that it sees the current R0; it is
        # consumed by train() within the same iteration.
        def train_input_fn():
            features = {}
            features['data'] = fin
            features['R0'] = R0
            features['x0'] = np.random.normal(size=fin.size).reshape(fin.shape)
            features['lr'] = 0.01
            return features, None

        recon_estimator.train(input_fn=train_input_fn,
                              max_steps=iiter + niter)
        eval_results = recon_estimator.predict(input_fn=eval_input_fn,
                                               yield_single_examples=False)

        # NOTE(review): the break fires only after `pred` has been rebound at
        # i == 1, so the second prediction is used when more than one is
        # yielded -- behaviour kept as is.
        for i, pred in enumerate(eval_results):
            if i > 0:
                break

        iiter += niter
        dg.saveimfig(iiter, [pred['ic'], pred['data']], [ic, fin],
                     fpath + '/figs/')
        dg.save2ptfig(iiter, [pred['ic'], pred['data']], [ic, fin],
                      fpath + '/figs/', bs)

    sys.exit(0)
Ejemplo n.º 7
0
def main(_):
    """Train the reconstruction estimator on a simulated final field.

    Loads ``ic.npy`` / ``final.npy`` from ``fpath`` (regenerating and saving
    them when loading fails), then for every smoothing value in ``RRs``
    trains the estimator for ``niter`` more steps, predicts, and writes
    diagnostic figures under ``fpath + '/figs/'``.  Exits with status 0.

    ``stages``, ``bs``, ``fpath`` and ``model_fn`` are not defined here and
    are read from the enclosing module.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    dtype = tf.float32
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    print(mesh_shape)
    ##Begin here
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Best effort: any load failure falls back to regenerating the data.
        print('Exception occured', e)
        tfic = linear_field(FLAGS.nc,
                            FLAGS.box_size,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        if FLAGS.nbody:
            state = lpt_init(tfic, a0=0.1, order=1)
            final_state = nbody(state, stages, FLAGS.nc)
        else:
            final_state = lpt_init(tfic, a0=stages[-1], order=1)
        tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
        with tf.Session() as sess:
            ic, fin = sess.run([tfic, tfinal_field])
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    print(ic.shape, fin.shape)
    ########################################################
    print(ic.shape, fin.shape)
    recon_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                             model_dir=fpath)

    def eval_input_fn():
        """Features for prediction: R0 = 0, lr = 0 and no starting field."""
        features = {}
        features['data'] = fin
        features['R0'] = 0
        features['x0'] = None
        features['lr'] = 0
        return features, None

    # Train and evaluate model.

    RRs = [4., 2., 1., 0.5, 0.]
    niter = 200
    iiter = 0

    for R0 in RRs:
        print('\nFor iteration %d and R=%0.1f\n' % (iiter, R0))

        # Redefined on every pass so that it sees the current R0; it is
        # consumed by train() within the same iteration.
        def train_input_fn():
            features = {}
            features['data'] = fin
            features['R0'] = R0
            features['x0'] = np.random.normal(size=fin.size).reshape(fin.shape)
            features['lr'] = 0.01
            return features, None

        recon_estimator.train(input_fn=train_input_fn,
                              max_steps=iiter + niter)
        eval_results = recon_estimator.predict(input_fn=eval_input_fn,
                                               yield_single_examples=False)

        # NOTE(review): the break fires only after `pred` has been rebound at
        # i == 1, so the second prediction is used when more than one is
        # yielded -- behaviour kept as is.
        for i, pred in enumerate(eval_results):
            if i > 0:
                break

        iiter += niter
        dg.saveimfig(iiter, [pred['ic'], pred['data']], [ic, fin],
                     fpath + '/figs/')
        dg.save2ptfig(iiter, [pred['ic'], pred['data']], [ic, fin],
                      fpath + '/figs/', bs)

    sys.exit(0)
Ejemplo n.º 8
0
def main(_):
    """Reconstruct the initial field with annealed TFP L-BFGS.

    Loads ``ic.npy`` / ``final.npy`` from ``fpath`` (regenerating and saving
    them when loading fails), plots their power spectra to ``pklin.png``,
    adds unit Gaussian noise to the final field and minimises
    ``chisq + prior`` with ``tfp.optimizer.lbfgs_minimize`` at smoothing
    values 2, 1 and 0, each stage starting from the previous minimum.  The
    minimum is then run through the forward model and saved as ``recon0ic``
    / ``recon-final``.  The process exits with status 0.

    Args:
        _: unused (the argument supplied by the app runner).
    """
    dtype = tf.float32

    tf.random.set_random_seed(100)
    np.random.seed(100)

    # Compute a few things first, using simple tensorflow
    a0 = FLAGS.a0
    a = FLAGS.af
    nsteps = FLAGS.nsteps
    bs, nc = FLAGS.box_size, FLAGS.nc
    # Parse the table once; rows 0 and 1 of the transpose feed the spline.
    pktable = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = pktable[0], pktable[1]
    ipklin = iuspline(klin, plin)
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Best effort: any load failure falls back to regenerating the data.
        print('Exception occured', e)
        tfic = linear_field(FLAGS.nc,
                            FLAGS.box_size,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        if FLAGS.nbody:
            state = lpt_init(tfic, a0=0.1, order=1)
            final_state = nbody(state, stages, FLAGS.nc)
        else:
            final_state = lpt_init(tfic, a0=stages[-1], order=1)
        tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
        with tf.Session() as sess:
            ic, fin = sess.run([tfic, tfinal_field])
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    # Power spectra of ic[0] + 1 and fin[0], plotted on log-log axes.
    k, pic = tools.power(ic[0] + 1, boxsize=bs)
    k, pfin = tools.power(fin[0], boxsize=bs)
    plt.plot(k, pic)
    plt.plot(k, pfin)
    plt.loglog()
    plt.grid(which='both')
    plt.savefig('pklin.png')
    plt.close()

    print(pic)
    print(pfin)

    ################################################################
    tf.reset_default_graph()
    print('ic constructed')

    noise = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(
        np.float32) * 1
    data_noised = fin + noise
    data = data_noised

    # The optimisation starts from the noise realisation itself, flattened.
    startpos = noise.copy().flatten().astype(np.float32)

    x0 = tf.placeholder(dtype=tf.float32,
                        shape=data.flatten().shape,
                        name='initlin')
    Rsm = tf.placeholder(tf.float32, name='smoothing')

    def recon_prototype(linearflat):
        """Return ``(loss, gradient)`` of ``chisq + prior`` at ``linearflat``."""
        linear = tf.reshape(linearflat, data.shape)

        state = lpt_init(linear, a0=0.1, order=1)
        final_state = nbody(state, stages, FLAGS.nc)
        final_field = cic_paint(tf.zeros_like(linear), final_state[0])

        base = final_field - data.astype(np.float32)
        # Weight the residual in Fourier space by exp(-kmesh**2 * Rsmsq).
        Rsmsq = tf.multiply(Rsm * bs / nc, Rsm * bs / nc)
        smwts = tf.exp(tf.multiply(-kmesh**2, Rsmsq))
        basek = r2c3d(base, norm=nc**3)
        basek = tf.multiply(basek, tf.cast(smwts, tf.complex64))
        base = c2r3d(basek, norm=nc**3)

        chisq = tf.multiply(base, base)
        chisq = tf.reduce_sum(chisq)
        chisq = tf.multiply(chisq, 1 / nc**3, name='chisq')

        # Prior: |FFT(linear)|^2 weighted by 1 / priorwt.
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1 / priorwt))
        prior = tf.multiply(prior, 1 / nc**3, name='prior')

        loss = chisq + prior

        grad = tf.gradients(loss, linearflat)
        print(grad)
        return loss, grad[0]

    @tf.function
    def min_lbfgs():
        return tfp.optimizer.lbfgs_minimize(
            recon_prototype,
            initial_position=x0,
            tolerance=1e-10,
            max_iterations=100)

    with tf.Session() as sess:
        # Anneal the smoothing value; every stage is warm-started from the
        # minimum of the previous one (the first from `startpos`).
        minimum = startpos
        for RR in (2, 1, 0):
            start = time.time()
            results = sess.run(min_lbfgs(), {Rsm: RR, x0: minimum})
            print("\n")
            print(results)
            print("\n")
            minimum = results.position
            print(minimum)
            print("\nTime taken : ", time.time() - start)

    tf.reset_default_graph()
    print("\n")
    print('\nminimized\n')

    # The generated linear field is multiplied by 0, so only `minimum`
    # (reshaped to the data shape) contributes to `tfic`.
    tfic = linear_field(
        FLAGS.nc, FLAGS.box_size, ipklin, batch_size=1, seed=100,
        dtype=dtype) * 0 + minimum.reshape(data_noised.shape)
    state = lpt_init(tfic, a0=0.1, order=1)
    final_state = nbody(state, stages, FLAGS.nc)
    tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
    with tf.Session() as sess:
        minic, minfin = sess.run([tfic, tfinal_field])

    dg.saveimfig(0, [minic, minfin], [ic, fin], fpath + '')
    dg.save2ptfig(0, [minic, minfin], [ic, fin], fpath + '', bs)

    np.save(fpath + 'recon0ic', minic)
    np.save(fpath + 'recon-final', minfin)

    # Same effect as exit(0) without relying on the site-provided builtin.
    raise SystemExit(0)
Ejemplo n.º 9
0
def main(_):
    """Reconstruct a linear field from Poisson-sampled data with L-BFGS.

    Loads the saved initial condition, final field and Poisson sample from
    ``../data/poisson_L%04d_N%03d/``, plots the outputs of ``tools.power``
    for the first two to ``pklin.png``, then minimises ``recon_prototype``
    with ``tfp.optimizer.lbfgs_minimize`` once per smoothing value in
    ``RRs``.  Each run starts from the position returned by the previous
    one, and figures plus ``.npy`` arrays are written after every run.
    The function ends the process with ``exit(0)``.

    The positional argument is ignored.  ``FLAGS``, ``plambda``, ``kmesh``,
    ``priorwt``, ``fpath``, ``pmgraph``, ``tools`` and ``dg`` are not defined
    here and must be available at module level.
    """

    dtype = tf.float32

    startw = time.time()

    # Fix both TensorFlow's and NumPy's seeds so runs are repeatable.
    tf.random.set_random_seed(100)
    np.random.seed(100)

    # Compute a few things first, using simple tensorflow
    a0 = FLAGS.a0
    a = FLAGS.af
    nsteps = FLAGS.nsteps
    bs, nc = FLAGS.box_size, FLAGS.nc
    # First two columns of the table are used as (k, P) spline nodes.
    klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    # Scale factors handed to nbody() inside recon_prototype.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    # (here the fields are read from disk rather than generated)
    ic = np.load('../data/poisson_L%04d_N%03d/ic.npy' % (bs, nc))
    fin = np.load('../data/poisson_L%04d_N%03d/final.npy' % (bs, nc))
    data = np.load('../data/poisson_L%04d_N%03d/psample_%0.2f.npy' %
                   (bs, nc, plambda))

    # Diagnostic plot of tools.power() for the loaded fields (log-log axes).
    k, pic = tools.power(ic[0] + 1, boxsize=bs)
    k, pfin = tools.power(fin[0], boxsize=bs)
    plt.plot(k, pic)
    plt.plot(k, pfin)
    plt.loglog()
    plt.grid(which='both')
    plt.savefig('pklin.png')
    plt.close()

    print(pic)
    print(pfin)
    #sys.exit(-1)

    ################################################################
    tf.reset_default_graph()
    print('ic constructed')

    #noise = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)*1
    #data_noised = fin + noise
    #data = data_noised

    # Random unit-variance starting point, flattened because the optimiser
    # is fed a 1-D position through the ``x0`` placeholder.
    startpos = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(
        np.float32) * 1
    startpos = startpos.flatten()

    # x0   : flat starting position for L-BFGS
    # xlin : linear field fed to pmgraph() to forward-model a result
    # Rsm  : smoothing value, fed per run from ``RRs``
    x0 = tf.placeholder(dtype=tf.float32,
                        shape=data.flatten().shape,
                        name='initlin')
    xlin = tf.placeholder(dtype=tf.float32, shape=data.shape, name='linfield')
    Rsm = tf.placeholder(tf.float32, name='smoothing')

    def recon_prototype(linearflat):
        """Return ``(loss, gradient)`` for a flattened linear field.

        The field is reshaped to ``data.shape``, evolved with ``lpt_init`` and
        ``nbody`` and painted with ``cic_paint``.  The loss is the negative
        summed Poisson log-probability of ``data`` for rate
        ``plambda * (1 + base)`` plus the prior term
        ``sum(|r2c3d(linear)|**2 / priorwt)``.  The gradient is taken with
        respect to ``linearflat``.
        """
        linear = tf.reshape(linearflat, data.shape)
        #

        #loss = tf.reduce_sum(tf.square(linear - minimum))
        state = lpt_init(linear, a0=0.1, order=1)
        final_state = nbody(state, stages, FLAGS.nc)
        final_field = cic_paint(tf.zeros_like(linear), final_state[0])
        #final_field = pmgraph(linear)
        base = final_field

        if FLAGS.anneal:
            # Multiply the transformed field by exp(-kmesh**2 * Rsmsq), where
            # Rsmsq = (Rsm * bs / nc)**2, then transform back.
            Rsmsq = tf.multiply(Rsm * bs / nc, Rsm * bs / nc)
            smwts = tf.exp(tf.multiply(-kmesh**2, Rsmsq))
            basek = r2c3d(base, norm=nc**3)
            basek = tf.multiply(basek, tf.cast(smwts, tf.complex64))
            base = c2r3d(basek, norm=nc**3)

        galmean = tfp.distributions.Poisson(rate=plambda * (1 + base))
        logprob = -tf.reduce_sum(galmean.log_prob(data))
        #logprob = tf.multiply(logprob, 1/nc**3, name='logprob')

        #Prior
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1 / priorwt))
        #prior = tf.multiply(prior, 1/nc**3, name='prior')
        #
        loss = logprob + prior

        grad = tf.gradients(loss, linearflat)
        print(grad)
        # tf.gradients returns a list; the single entry matches linearflat.
        return loss, grad[0]

    @tf.function
    def min_lbfgs():
        """Build the L-BFGS minimisation of ``recon_prototype`` from ``x0``."""
        return tfp.optimizer.lbfgs_minimize(
            #make_val_and_grad_fn(recon_prototype),
            recon_prototype,
            initial_position=x0,
            tolerance=1e-10,
            max_iterations=FLAGS.niter)

    # Forward model used to turn each reconstructed field into a final field.
    tfinal_field = pmgraph(xlin)

    # Smoothing values, run in this order.
    RRs = [2.0, 1.0, 0.0]
    start0 = time.time()
    with tf.Session() as sess:

        for iR, RR in enumerate(RRs):

            start = time.time()
            # min_lbfgs() is called on every pass, so its ops are created
            # again for each smoothing value.
            results = sess.run(min_lbfgs(), {Rsm: RR, x0: startpos})
            print("\n")
            print(results)
            print("\n")
            # The next run starts from this run's position.
            startpos = results.position
            print(startpos)
            print("\nTime taken for %d iterations: " % FLAGS.niter,
                  time.time() - start)

            minic = startpos.reshape(data.shape)
            minfin = sess.run(tfinal_field, {xlin: minic})
            # "%d" truncates the float RR, giving labels R2, R1 and R0.
            dg.saveimfig("R%d" % RR, [minic, minfin], [ic, fin], fpath + '')
            dg.save2ptfig("R%d" % RR, [minic, minfin], [ic, fin], fpath + '',
                          bs)

            np.save(fpath + 'recon-icR%d' % RR, minic)
            np.save(fpath + 'recon-finalR%d' % RR, minfin)

    #tf.reset_default_graph()
    print("\n")
    print('\nminimized\n')
    print('\nTotal time taken %d iterations: ' % (len(RRs) * FLAGS.niter),
          time.time() - start0)

    #tfic = linear_field(FLAGS.nc, FLAGS.box_size, ipklin, batch_size=1, seed=100, dtype=dtype)*0 + minimum.reshape(data_noised.shape)
    #state = lpt_init(tfic, a0=0.1, order=1)
    #final_state = nbody(state,  stages, FLAGS.nc)
    #tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])

    ##
    exit(0)
Ejemplo n.º 10
0
def main():
    """Reconstruct the initial field of one saved Poisson sample.

    Loads ``(ic, fin, data)`` from ``0000.npy`` in the data directory selected
    by ``args.nbody``, then, for every smoothing value in ``RRs``, minimises
    ``recon_prototype`` with the optimiser named by the module-level
    ``optimizer`` (``'scipy-lbfgs'``, ``'tf2-lbfgs'`` or ``'adam'``).  After
    each smoothing value the result is forward-modelled with ``pm``, figures
    are saved through ``dg``, and the result becomes the next starting point.
    The function ends the process with ``exit(0)``.

    ``args``, ``bs``, ``nc``, ``nsteps``, ``plambda``, ``anneal``, ``kmesh``,
    ``priorwt``, ``RRs``, ``optimizer``, ``niter``, ``lr``, ``fpath``, ``pm``
    and ``dg`` are not defined here and must be available at module level.

    Raises:
        ValueError: if ``optimizer`` is not one of the three supported names.
    """

    startw = time.time()

    # Both directory templates end in a complete "%03d" conversion.  The
    # n-body one used to read "%03/", which raises
    # "ValueError: unsupported format character '/'" as soon as it is used.
    if args.nbody:
        dpath = '/project/projectdirs/m3058/chmodi/rim-data/poisson_L%04d_N%03d_T%02d_p%03d/' % (
            bs, nc, nsteps, plambda * 100)
    else:
        dpath = '/project/projectdirs/m3058/chmodi/rim-data/poisson_L%04d_N%03d_LPT%d_p%03d/' % (
            bs, nc, args.lpt_order, plambda * 100)
    ic, fin, data = np.load(dpath + '%04d.npy' % 0)
    # Add a leading axis of length 1 to each loaded field.
    ic, fin, data = np.expand_dims(ic, 0), np.expand_dims(fin,
                                                          0), np.expand_dims(
                                                              data, 0)
    print(ic.shape, fin.shape, data.shape)

    # Sanity check: ratio of the stored final field to pm(ic).
    check = pm(tf.constant(ic)).numpy()
    print(fin / check)
    #ic = np.load('../data/poisson_L%04d_N%03d/ic.npy'%(bs, nc))
    #fin = np.load('../data/poisson_L%04d_N%03d/final.npy'%(bs, nc))
    #data = np.load('../data/poisson_L%04d_N%03d/psample_%0.2f.npy'%(bs, nc, plambda))

    @tf.function
    def recon_prototype(linear, Rsm=0):
        """Return the scalar loss for a candidate linear field.

        ``linear`` is reshaped to ``data.shape`` and forward-modelled with
        ``pm``.  The loss is the negative summed Poisson log-probability of
        ``data`` for rate ``plambda * (1 + base)`` plus the prior term
        ``sum(|r2c3d(linear)|**2 / priorwt)``, each divided by ``nc**3``.
        When ``anneal`` is set, ``base`` is first multiplied in transform
        space by ``exp(-kmesh**2 * (Rsm * bs / nc)**2)``.
        """

        linear = tf.reshape(linear, data.shape)
        #loss = tf.reduce_sum(tf.square(linear - minimum))
        final_field = pm(linear)
        base = final_field

        if anneal:
            print('\nAdd annealing graph\n')
            Rsmsq = tf.multiply(Rsm * bs / nc, Rsm * bs / nc)
            smwts = tf.exp(tf.multiply(-kmesh**2, Rsmsq))
            basek = r2c3d(base, norm=nc**3)
            basek = tf.multiply(basek, tf.cast(smwts, tf.complex64))
            base = c2r3d(basek, norm=nc**3)

        galmean = tfp.distributions.Poisson(rate=plambda * (1 + base))
        sample = galmean.sample()
        logprob = -tf.reduce_sum(galmean.log_prob(data))
        logprob = tf.multiply(logprob, 1 / nc**3, name='logprob')

        #Prior
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1 / priorwt))
        prior = tf.multiply(prior, 1 / nc**3, name='prior')
        #
        loss = logprob + prior

        return loss

    #Loop it Reconstruction
    ##Reconstruction
    # Random starting field; also the initial value of the Adam variable.
    x0 = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)
    linear = tf.Variable(name='linmesh',
                         shape=(1, nc, nc, nc),
                         dtype=tf.float32,
                         initial_value=x0,
                         trainable=True)

    ##
    for iR, RR in enumerate(RRs):

        # Both helpers close over the current RR and are only used inside
        # this iteration.
        @tf.function
        def val_and_grad(x):
            """Return ``(loss, d loss / d x)`` at smoothing value RR."""
            with tf.GradientTape() as tape:
                tape.watch(x)
                loss = recon_prototype(x, tf.constant(RR, dtype=tf.float32))
            grad = tape.gradient(loss, x)
            return loss, grad

        @tf.function
        def grad(x):
            """Return ``d loss / d x`` at smoothing value RR."""
            with tf.GradientTape() as tape:
                tape.watch(x)
                loss = recon_prototype(x, tf.constant(RR, dtype=tf.float32))
            grad = tape.gradient(loss, x)
            return grad

        start = time.time()

        #
        if optimizer == 'scipy-lbfgs':

            def func(x):
                """Adapter for scipy: float64 ``[loss, grad]`` from float32."""
                return [
                    vv.numpy().astype(np.float64)
                    for vv in val_and_grad(x=tf.constant(x, dtype=tf.float32))
                ]  #

            results = sopt.minimize(
                fun=func,
                x0=x0,
                jac=True,
                method='L-BFGS-B',
                #tol=1e-10,
                options={
                    'maxiter': niter,
                    'ftol': 1e-12,
                    'gtol': 1e-12,
                    'eps': 1e-12
                })
            #options={'maxiter':niter})
            print(results)
            minic = results.x.reshape(data.shape)

        #
        elif optimizer == 'tf2-lbfgs':

            @tf.function
            def min_lbfgs(x0):
                """Run TFP L-BFGS on ``val_and_grad`` from ``x0``."""
                return tfp.optimizer.lbfgs_minimize(val_and_grad,
                                                    initial_position=x0,
                                                    tolerance=1e-10,
                                                    max_iterations=niter)

            results = min_lbfgs(x0.flatten())
            print(results)
            minic = results.position.numpy().reshape(data.shape)

        #
        elif optimizer == 'adam':

            # A new Adam optimiser is created per smoothing value; the
            # variable ``linear`` keeps its value between them.
            opt = tf.keras.optimizers.Adam(learning_rate=lr)
            for i in range(niter):
                grads = grad([linear])
                opt.apply_gradients(zip(grads, [linear]))
            minic = linear.numpy().reshape(data.shape)

        else:
            # Without this branch an unknown name only surfaced later as a
            # NameError on ``minic``.
            raise ValueError("Unknown optimizer %r; expected 'scipy-lbfgs', "
                             "'tf2-lbfgs' or 'adam'" % (optimizer, ))

        #
        print('\nminimized\n')
        print("Time taken for maxiter %d : " % niter, time.time() - start)

        minfin = pm(tf.constant(minic, dtype=tf.float32)).numpy()
        dg.saveimfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '')
        dg.save2ptfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '', bs)
        ###
        # The next smoothing value starts from this result.
        x0 = minic

    exit(0)
Ejemplo n.º 11
0
def main(_):
    """Run a mesh-tensorflow reconstruction across a SLURM-resolved cluster.

    Steps, in order:

    1. Resolve the cluster from SLURM, start a ``tf.distribute.Server`` and,
       on tasks with ``task_id > 0``, call ``server.join()``.
    2. Load ``ic.npy`` / ``final.npy`` from ``fpath``, or generate and save
       them if loading raises.
    3. Build the mesh-tensorflow graph with ``recon_prototype`` and lower it.
    4. Check the forward model on the true ``ic``, reset to a random field,
       then run 301 update steps for every ``(RR, lR)`` pair in
       ``zip(RRs, lrs)``, saving figures along the way.
    5. Save the last evaluated fields and end the process with ``exit(0)``.

    The positional argument is ignored.  ``FLAGS``, ``fpath``, ``bs``,
    ``stages``, ``recon_prototype`` and ``dg`` are not defined here and must
    be available at module level.
    """

    dtype = tf.float32
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

    print(mesh_shape)

    #layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    #mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    # Pairs of (tensor dimension name, mesh dimension name).
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)
    cluster_spec = cluster.cluster_spec()
    print(cluster_spec)
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)
    print(server)

    # Every task except task 0 hands control to the server here.
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_task)
    ]
    print("List of devices", devices)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)

    ##Begin here
    # First two columns of the table are used as (k, P) spline nodes.
    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Any failure to load falls back to generating the fields and
        # saving them for the next run.
        print('Exception occured', e)
        tfic = linear_field(FLAGS.nc,
                            FLAGS.box_size,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        if FLAGS.nbody:
            state = lpt_init(tfic, a0=0.1, order=1)
            final_state = nbody(state, stages, FLAGS.nc)
        else:
            final_state = lpt_init(tfic, a0=stages[-1], order=1)
        tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
        with tf.Session(server.target) as sess:
            ic, fin = sess.run([tfic, tfinal_field])
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    tf.reset_default_graph()
    print('ic constructed')

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    initial_conditions, final_field, loss, var_grads, update_op, linear_op, input_field, lr, R0 = recon_prototype(
        mesh, fin, nc=FLAGS.nc, batch_size=FLAGS.batch_size, dtype=dtype)

    # Lower mesh computation

    start = time.time()
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)
    end = time.time()
    print('\n Time for lowering : %f \n' % (end - start))

    # Plain TensorFlow handles for the lowered tensors and operations.
    tf_initc = lowering.export_to_tf_tensor(initial_conditions)
    tf_final = lowering.export_to_tf_tensor(final_field)
    tf_grads = lowering.export_to_tf_tensor(var_grads[0])
    tf_linear_op = lowering.lowered_operation(linear_op)
    tf_update_ops = lowering.lowered_operation(update_op)
    n_block_x, n_block_y, n_block_z = FLAGS.nx, FLAGS.ny, 1
    nc = FLAGS.nc
    # Split each spatial axis into (block index, offset within block) and
    # move the three block indices in front of the three offsets.
    ic_hrshape = ic.reshape([
        FLAGS.batch_size, n_block_x, nc // n_block_x, n_block_y,
        nc // n_block_y, n_block_z, nc // n_block_z
    ])
    ic_hrshape = np.transpose(ic_hrshape, [0, 1, 3, 5, 2, 4, 6])
    with tf.Session(server.target) as sess:

        #ic_check, fin_check = sess.run([tf_initc, tf_final])
        # Load the true initial condition and compare the resulting fields
        # with the stored ones.
        sess.run(tf_linear_op, feed_dict={input_field: ic_hrshape})
        ic_check, fin_check = sess.run([tf_initc, tf_final])
        dg.saveimfig('-check', [ic_check, fin_check], [ic, fin], fpath)
        dg.save2ptfig('-check', [ic_check, fin_check], [ic, fin], fpath, bs)

        # Replace it with a random field as the starting point.
        sess.run(tf_linear_op,
                 feed_dict={
                     input_field:
                     np.random.normal(size=ic.size).reshape(ic_hrshape.shape)
                 })
        ic0, fin0 = sess.run([tf_initc, tf_final])
        dg.saveimfig('-init', [ic0, fin0], [ic, fin], fpath)
        start = time.time()

        # niter : evaluate and plot every this many update steps
        # iiter : running count of update steps over all (RR, lR) pairs
        niter = 5
        iiter = 0
        start0 = time.time()
        RRs = [4, 2, 1, 0.5, 0]
        lrs = np.array([0.2, 0.15, 0.1, 0.1, 0.1])
        #lrs = [0.1, 0.05, 0.01, 0.005, 0.001]

        for iR, zlR in enumerate(zip(RRs, lrs)):
            RR, lR = zlR
            #for ff in [fpath + '/figs-R%02d'%(10*RR)]:
            # Best-effort directory creation: a failure (for example the
            # directory already existing) is printed and ignored.
            for ff in [fpath + '/figsiter']:
                try:
                    os.makedirs(ff)
                except Exception as e:
                    print(e)
            for i in range(301):
                if (i % niter == 0):
                    end = time.time()
                    print('Iter : ', i)
                    print('Time taken for %d iterations: ' % niter,
                          end - start)
                    start = end
                    ##
                    ic1, fin1 = sess.run([tf_initc, tf_final])

                    #dg.saveimfig(i, [ic1, fin1], [ic, fin], fpath+'/figs-R%02d'%(10*RR))
                    #dg.save2ptfig(i, [ic1, fin1], [ic, fin], fpath+'/figs-R%02d'%(10*RR), bs)
                    dg.saveimfig2x2(iiter, [ic1, fin1], [ic, fin],
                                    fpath + '/figsiter')
                    #
                sess.run(tf_update_ops, {lr: lR, R0: RR})
                iiter += 1

            # ic1/fin1 were evaluated at i == 300, before that step's update.
            dg.saveimfig(i * (iR + 1), [ic1, fin1], [ic, fin], fpath + '/figs')
            dg.save2ptfig(i * (iR + 1), [ic1, fin1], [ic, fin],
                          fpath + '/figs', bs)

        # Final evaluation after all updates.
        ic1, fin1 = sess.run([tf_initc, tf_final])
        print('Total time taken for %d iterations is : ' % iiter,
              time.time() - start0)

    # ``i`` still holds its last loop value (300) here.
    dg.saveimfig(i, [ic1, fin1], [ic, fin], fpath)
    dg.save2ptfig(i, [ic1, fin1], [ic, fin], fpath, bs)

    np.save(fpath + 'ic_recon', ic1)
    np.save(fpath + 'final_recon', fin1)
    print('Total wallclock time is : ', time.time() - start0)

    ##
    exit(0)
Ejemplo n.º 12
0
def main(_):
    """Run a mesh-tensorflow reconstruction against a painted halo-mass field.

    Steps, in order:

    1. Resolve the cluster from SLURM, start a ``tf.distribute.Server`` and,
       on tasks with ``task_id > 0``, call ``server.join()``.
    2. Read the meshes and halo catalogue with ``tools.readbigfile``, paint
       the halo masses with ``tools.paintcic`` and smooth the result with
       ``tools.fingauss`` to obtain ``datasm``.
    3. Build the graph with ``recon_prototype`` and lower it.
    4. Optionally (``readin``) restart from a saved ``ic_recon.npy`` and skip
       the ``(mm, RR, ww)`` combinations ordered before the restart point.
    5. For each remaining combination run ``niters[iR]`` update steps,
       logging and plotting every ``titer`` steps, then save the fields.
    6. Save the final fields and end the process with ``exit(0)``.

    The positional argument is ignored.  ``FLAGS``, ``fpath``, ``tools``,
    ``dg``, ``setnoise`` and ``recon_prototype`` are not defined here and
    must be available at module level.
    """

    infield = True
    dtype = tf.float32
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    nc, bs = FLAGS.nc, FLAGS.box_size
    a0, a, nsteps = FLAGS.a0, FLAGS.af, FLAGS.nsteps
    stages = np.linspace(a0, a, nsteps, endpoint=True)
    # Scales the number of catalogue entries kept: int(bs**3 * numd).
    numd = 1e-3

    startw = time.time()

    print(mesh_shape)

    #layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    #mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    # Pairs of (tensor dimension name, mesh dimension name).
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)
    cluster_spec = cluster.cluster_spec()
    print(cluster_spec)
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)
    print(server)

    # Every task except task 0 hands control to the server here.
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_task)
    ]
    print("List of devices", devices)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)

    ##Begin here
    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)

    final = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0128_S0100_05step/mesh/d/'
    )
    ic = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0128_S0100_05step/mesh/s/'
    )

    pypath = '/global/cscratch1/sd/chmodi/cosmo4d/output/version2/L0400_N0128_05step-fof/lhd_S0100/n10/opt_s999_iM12-sm3v25off/meshes/'
    fin = tools.readbigfile(pypath + 'decic//')

    # Catalogue slices start at index 1, so entry 0 is not used.
    hpos = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0512_S0100_40step/FOF/PeakPosition//'
    )[1:int(bs**3 * numd)]
    hmass = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0512_S0100_40step/FOF/Mass//'
    )[1:int(bs**3 * numd)].flatten()

    #meshpos = tools.paintcic(hpos, bs, nc)
    meshmass = tools.paintcic(hpos, bs, nc, hmass.flatten() * 1e10)
    data = meshmass
    kv = tools.fftk([nc, nc, nc], bs, symmetric=True, dtype=np.float32)
    datasm = tools.fingauss(data, kv, 3, np.pi * nc / bs)
    ic, data = np.expand_dims(ic, 0), np.expand_dims(data,
                                                     0).astype(np.float32)
    datasm = np.expand_dims(datasm, 0).astype(np.float32)
    print("Min in data : %0.4e" % datasm.min())

    # NOTE(review): this repeats the expand_dims above, so ``ic`` and ``data``
    # gain two leading axes while ``datasm`` gains one.  It looks accidental
    # but is left untouched; confirm against the shape ``input_field``
    # expects before removing it.
    ic, data = np.expand_dims(ic, 0), np.expand_dims(data,
                                                     0).astype(np.float32)
    np.save(fpath + 'ic', ic)
    np.save(fpath + 'data', data)

    ####################################################
    tf.reset_default_graph()
    print('ic constructed')

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    initial_conditions, data_field, loss, var_grads, update_op, linear_op, input_field, lr, R0, M0, width, chisq, prior, tf_off, tf_istd = recon_prototype(
        mesh, datasm, nc=FLAGS.nc, batch_size=FLAGS.batch_size, dtype=dtype)

    # Lower mesh computation

    start = time.time()
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)
    end = time.time()
    print('\n Time for lowering : %f \n' % (end - start))

    # Plain TensorFlow handles for the lowered tensors and operations.
    tf_initc = lowering.export_to_tf_tensor(initial_conditions)
    tf_data = lowering.export_to_tf_tensor(data_field)
    tf_chisq = lowering.export_to_tf_tensor(chisq)
    tf_prior = lowering.export_to_tf_tensor(prior)
    tf_grads = lowering.export_to_tf_tensor(var_grads[0])
    #tf_lr = lowering.export_to_tf_tensor(lr)
    tf_linear_op = lowering.lowered_operation(linear_op)
    tf_update_ops = lowering.lowered_operation(update_op)
    n_block_x, n_block_y, n_block_z = FLAGS.nx, FLAGS.ny, 1
    nc = FLAGS.nc

    with tf.Session(server.target) as sess:

        # Load the true initial condition and compare the modelled data
        # with the painted data.
        start = time.time()
        sess.run(tf_linear_op, feed_dict={input_field: ic})
        ic_check, data_check = sess.run([tf_initc, tf_data], {width: 3})

        dg.saveimfig('-check', [ic_check, data_check], [ic, data],
                     fpath + '/figs/')
        dg.save2ptfig('-check', [ic_check, data_check], [ic, data],
                      fpath + '/figs/', bs)
        print('Total time taken for mesh thingy is : ', time.time() - start)

        # Replace it with a random field as the starting point.
        sess.run(tf_linear_op,
                 feed_dict={
                     input_field:
                     np.random.normal(size=ic.size).reshape(ic.shape)
                 })
        ic0, data0 = sess.run([tf_initc, tf_data], {width: 3})
        dg.saveimfig('-init', [ic0, data0], [ic, data], fpath)
        start = time.time()

        # titer : log and plot every this many update steps
        # niter : default number of update steps per configuration
        # iiter : running count of update steps over all configurations
        titer = 20
        niter = 101
        iiter = 0

        start0 = time.time()
        RRs = [4, 2, 1, 0.5, 0]
        wws = [1, 2, 3]
        lrs = np.array([0.1, 0.1, 0.1, 0.1, 0.1]) * 2
        #lrs = [0.1, 0.05, 0.01, 0.005, 0.001]
        # ``niters`` is reassigned at the end of the first ``mm`` pass, which
        # makes it a local name.  Without an initial value the first read in
        # the inner loop raised UnboundLocalError.  The otherwise unused
        # ``niter`` is taken as the intended per-stage step count.
        niters = [niter] * len(RRs)

        # Restart support: load the field saved for (mm0, RR0, ww0) and skip
        # the configurations ordered before it in the loops below.
        readin = True
        mm0, ww0, RR0 = 1e12, 3, 0.5
        if readin:
            icread = np.load(fpath + '/figs-M%02d-R%02d-w%01d/ic_recon.npy' %
                             (np.log10(mm0), 10 * RR0, ww0))
            sess.run(tf_linear_op, feed_dict={input_field: icread})

        for mm in [1e12, 1e11]:
            print('Fraction of points above 1 for mm = %0.2e: ' % mm,
                  (datasm > mm).sum() / datasm.size)
            noisefile = '/project/projectdirs/m3058/chmodi/cosmo4d/train/L0400_N0128_05step-n10/width_3/Wts_30_10_1/r1rf1/hlim-13_nreg-43_batch-5/eluWts-10_5_1/blim-20_nreg-23_batch-100/hist_M%d_na.txt' % (
                np.log10(mm) * 10)
            offset, ivar = setnoise(datasm, noisefile, noisevar=0.25)
            for iR, zlR in enumerate(zip(RRs, lrs)):
                RR, lR = zlR
                for ww in wws:
                    # Best-effort creation of this configuration's output
                    # directory; ``ff`` is reused below for saving.
                    for ff in [
                            fpath + '/figs-M%02d-R%02d-w%01d' %
                        (np.log10(mm), 10 * RR, ww)
                    ]:
                        try:
                            os.makedirs(ff)
                        except Exception as e:
                            print(e)
                    if readin:
                        if mm > mm0: continue
                        elif mm == mm0 and RR > RR0:
                            print(RR, RR0, RRs)
                            continue
                        elif RR == RR0 and ww <= ww0:
                            print(ww, ww0, wws)
                            continue
                        else:
                            print('Starting from %0.2e' % mm, RR, ww)
                    print('Do for %0.2e' % mm, RR, ww)

                    for i in range(niters[iR]):
                        iiter += 1
                        sess.run(
                            tf_update_ops, {
                                lr: lR,
                                M0: mm,
                                R0: RR,
                                width: ww,
                                tf_off: offset,
                                tf_istd: ivar**0.5
                            })
                        if (i % titer == 0):
                            end = time.time()
                            print('Iter : ', i)
                            print('Time taken for %d iterations: ' % titer,
                                  end - start)
                            start = end

                            ##
                            ic1, data1, cc, pp = sess.run(
                                [tf_initc, tf_data, tf_chisq, tf_prior], {
                                    M0: mm,
                                    R0: RR,
                                    width: ww,
                                    tf_off: offset,
                                    tf_istd: ivar**0.5
                                })
                            print('Chisq and prior are : ', cc, pp)

                            dg.saveimfig(i, [ic1, data1], [ic, data], ff)
                            dg.save2ptfig(i, [ic1, data1], [ic, data], ff, bs)

                    # Save this configuration's result next to its figures.
                    ic1, data1 = sess.run([tf_initc, tf_data], {width: ww})
                    np.save(ff + '/ic_recon', ic1)
                    np.save(ff + '/data_recon', data1)
                    dg.saveimfig(iiter, [ic1, data1], [ic, data],
                                 fpath + '/figs')
                    dg.save2ptfig(iiter, [ic1, data1], [ic, data],
                                  fpath + '/figs', bs)

            # Schedule for every later ``mm``: one smoothing value, one width.
            wws = [3]
            RRs = [0]
            niters = [201, 101, 201]
            lrs = np.array([0.1, 0.1, 0.1])

        ic1, data1 = sess.run([tf_initc, tf_data], {width: 3})
        print('Total time taken for %d iterations is : ' % iiter,
              time.time() - start0)

    dg.saveimfig('', [ic1, data1], [ic, data], fpath)
    dg.save2ptfig('', [ic1, data1], [ic, data], fpath, bs)

    np.save(fpath + 'ic_recon', ic1)
    np.save(fpath + 'data_recon', data1)
    print('Total wallclock time is : ', time.time() - start0)

    ##
    exit(0)
Ejemplo n.º 13
0
def main(_):
    """Reconstruct initial conditions from a shifted halo field via Estimator.

    Steps, in order:

    1. Read meshes, particle positions and the halo catalogue with
       ``tools.readbigfile``; shift halo positions along the last axis by
       ``velocity * rsdfactor``; paint and smooth the halo-mass field.
    2. Build a ``tf.estimator.Estimator`` around ``model_fn`` with
       ``model_dir=fpath``.
    3. Save predictions for the true ``ic`` (suffix ``-model``) and for a
       random field (suffix ``-init``).
    4. For every ``(mm, R0, ww)`` combination, train up to
       ``iiter + niter`` steps, predict, and save figures and meshes.
    5. End the process with ``sys.exit(0)``.

    The positional argument is ignored.  ``FLAGS``, ``fpath``, ``tools``,
    ``dg``, ``srecon``, ``cosmo``, ``setnoise`` and ``model_fn`` are not
    defined here and must be available at module level.
    """

    infield = True
    dtype = tf.float32
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    nc, bs = FLAGS.nc, FLAGS.box_size
    a0, a, nsteps = FLAGS.a0, FLAGS.af, FLAGS.nsteps
    stages = np.linspace(a0, a, nsteps, endpoint=True)
    # Scales the number of catalogue entries kept: int(bs**3 * numd).
    numd = 1e-3

    ##Begin here
    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)

    #pypath = '/global/cscratch1/sd/chmodi/cosmo4d/output/version2/L0400_N0128_05step-fof/lhd_S0100/n10/opt_s999_iM12-sm3v25off/meshes/'
    final = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0128_S0100_05step/mesh/d/'
    )
    ic = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0128_S0100_05step/mesh/s/'
    )
    fpos = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0128_S0100_05step/dynamic/1/Position/'
    )
    # With aa = 1 this gives zz = 0 and rsdfactor = 100 / H(0).
    aa = 1
    zz = 1 / aa - 1
    rsdfactor = float(100 / (aa**2 * cosmo.H(zz).value**1))
    print('\nRsdfactor used is : ', rsdfactor)

    # Catalogue slices start at index 1, so entry 0 is not used.
    hpos = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0512_S0100_40step/FOF/PeakPosition//'
    )[1:int(bs**3 * numd)]
    hvel = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0512_S0100_40step/FOF/CMVelocity//'
    )[1:int(bs**3 * numd)]
    # Only the last coordinate is shifted: the [0, 0, 1] mask zeroes the
    # velocity contribution on the first two axes.
    rsdpos = hpos + hvel * rsdfactor * np.array([0, 0, 1])
    print('Effective displacement : ', (hvel[:, -1] * rsdfactor).std())
    hmass = tools.readbigfile(
        '/project/projectdirs/m3058/chmodi/cosmo4d/data/L0400_N0512_S0100_40step/FOF/Mass//'
    )[1:int(bs**3 * numd)].flatten()

    # meshpos : shifted positions painted without weights
    # meshmass: shifted positions painted with hmass * 1e10 as weights
    meshpos = tools.paintcic(rsdpos, bs, nc)
    meshmass = tools.paintcic(rsdpos, bs, nc, hmass.flatten() * 1e10)
    data = meshmass
    kv = tools.fftk([nc, nc, nc], bs, symmetric=True, dtype=np.float32)
    datasm = tools.fingauss(data, kv, 3, np.pi * nc / bs)
    # Add a leading axis of length 1 to each field.
    ic, data = np.expand_dims(ic, 0), np.expand_dims(data,
                                                     0).astype(np.float32)
    datasm = np.expand_dims(datasm, 0).astype(np.float32)
    print("Min in data : %0.4e" % datasm.min())

    #

    ####################################################

    # Starting field used as ``x0`` for every training call below.
    stdinit = srecon.standardinit(bs, nc, meshpos, hpos, final, R=8)
    recon_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                             model_dir=fpath)

    def predict_input_fn(data=data,
                         M0=0.,
                         w=3.,
                         R0=0.,
                         off=None,
                         istd=None,
                         x0=None):
        """Return ``(features, None)`` for ``Estimator.predict``.

        ``features`` holds the arguments under the keys ``datasm``, ``M0``,
        ``w``, ``R0``, ``off``, ``istd`` and ``x0``, plus the enclosing
        ``rsdfactor``.  Note that ``data`` defaults to the unsmoothed
        ``data`` array even though it is stored under ``'datasm'``.
        """
        features = {}
        features['datasm'] = data
        features['rsdfactor'] = rsdfactor
        features['M0'] = M0
        features['w'] = w
        features['R0'] = R0
        features['off'] = off
        features['istd'] = istd
        features['x0'] = x0
        return features, None

    # Prediction starting from the true initial condition.
    eval_results = recon_estimator.predict(
        input_fn=lambda: predict_input_fn(x0=ic), yield_single_examples=False)

    # Keep only the first prediction: the loop breaks on the second item,
    # leaving ``pred`` bound to... the second item is discarded only if the
    # generator yields exactly one; otherwise ``pred`` is the second item.
    # NOTE(review): confirm which prediction is intended here.
    for i, pred in enumerate(eval_results):
        if i > 0: break

    suff = '-model'
    dg.saveimfig(suff, [pred['ic'], pred['model']], [ic, data],
                 fpath + '/figs/')
    dg.save2ptfig(suff, [pred['ic'], pred['model']], [ic, data],
                  fpath + '/figs/', bs)
    np.save(fpath + '/reconmeshes/ic_true' + suff, pred['ic'])
    np.save(fpath + '/reconmeshes/fin_true' + suff, pred['final'])
    np.save(fpath + '/reconmeshes/model_true' + suff, pred['model'])

    #
    # Prediction starting from a random field.
    randominit = np.random.normal(size=data.size).reshape(data.shape)
    #eval_results = recon_estimator.predict(input_fn=lambda : predict_input_fn(x0 = np.expand_dims(stdinit, 0)), yield_single_examples=False)
    eval_results = recon_estimator.predict(
        input_fn=lambda: predict_input_fn(x0=randominit),
        yield_single_examples=False)

    for i, pred in enumerate(eval_results):
        if i > 0: break

    suff = '-init'
    dg.saveimfig(suff, [pred['ic'], pred['model']], [ic, data],
                 fpath + '/figs/')
    dg.save2ptfig(suff, [pred['ic'], pred['model']], [ic, data],
                  fpath + '/figs/', bs)
    np.save(fpath + '/reconmeshes/ic_init' + suff, pred['ic'])
    np.save(fpath + '/reconmeshes/fin_init' + suff, pred['final'])
    np.save(fpath + '/reconmeshes/model_init' + suff, pred['model'])

    #
    # Train and evaluate model.
    # mms/RRs/wws are looped in that nesting order; niter is the number of
    # extra steps allowed per combination and iiter the running total.
    mms = [1e12, 1e11]
    wws = [1., 2., 3.]
    RRs = [4., 2., 1., 0.5, 0.]
    niter = 100
    iiter = 0

    for mm in mms:

        noisefile = '/project/projectdirs/m3058/chmodi/cosmo4d/train/L0400_N0128_05step-n10/width_3/Wts_30_10_1/r1rf1/hlim-13_nreg-43_batch-5/eluWts-10_5_1/blim-20_nreg-23_batch-100/hist_M%d_na.txt' % (
            np.log10(mm) * 10)
        offset, ivar = setnoise(datasm, noisefile, noisevar=0.25)
        istd = ivar**0.5
        # Either noise input can be disabled from the command line.
        if not FLAGS.offset: offset = None
        if not FLAGS.istd: istd = None

        for R0 in RRs:

            for ww in wws:

                print('\nFor iteration %d\n' % iiter)
                print('With mm=%0.2e, R0=%0.2f, ww=%d \n' % (mm, R0, ww))

                def train_input_fn():
                    """Return ``(features, None)`` for the current settings.

                    Reads ``mm``, ``R0``, ``ww``, ``offset`` and ``istd``
                    from the enclosing loops at call time.
                    """
                    features = {}
                    features['datasm'] = datasm
                    features['rsdfactor'] = rsdfactor
                    features['M0'] = mm
                    features['w'] = ww
                    features['R0'] = R0
                    features['off'] = offset
                    features['istd'] = istd
                    features['x0'] = np.expand_dims(
                        stdinit, 0
                    )  #np.random.normal(size=datasm.size).reshape(datasm.shape)
                    features['lr'] = 0.01
                    return features, None

                # max_steps is a cumulative bound, hence iiter + niter.
                recon_estimator.train(input_fn=train_input_fn,
                                      max_steps=iiter + niter)
                # predict_input_fn is called with all of its defaults here
                # (M0=0., w=3., R0=0., off=None, istd=None, x0=None).
                eval_results = recon_estimator.predict(
                    input_fn=predict_input_fn, yield_single_examples=False)

                for i, pred in enumerate(eval_results):
                    if i > 0: break

                iiter += niter  #
                suff = '-%d-M%d-R%d-w%d' % (iiter, np.log10(mm), R0, ww)
                dg.saveimfig(suff, [pred['ic'], pred['model']], [ic, data],
                             fpath + '/figs/')
                dg.save2ptfig(suff, [pred['ic'], pred['model']], [ic, data],
                              fpath + '/figs/', bs)
                # NOTE(review): "%d" truncates R0, so R0=0.5 and R0=0. give
                # the same "-R0-" suffix and the later save replaces the
                # earlier mesh files.  Confirm whether that is intended.
                suff = '-M%d-R%d-w%d' % (np.log10(mm), R0, ww)
                np.save(fpath + '/reconmeshes/ic' + suff, pred['ic'])
                np.save(fpath + '/reconmeshes/fin' + suff, pred['final'])
                np.save(fpath + '/reconmeshes/model' + suff, pred['model'])

        # Schedule for every later ``mm``: fewer smoothing values, one
        # width, more steps per combination.
        RRs = [1., 0.5, 0.]
        wws = [3.]
        niter = 200

    sys.exit(0)

    ##
    # NOTE(review): unreachable, because sys.exit(0) above raises SystemExit.
    exit(0)
Ejemplo n.º 14
0
def main(_):
    """Reconstruct initial conditions from a noisy final field with L-BFGS.

    Steps, in order:
      1. Load cached ``ic``/``final`` meshes from ``fpath``; if loading fails,
         generate them (N-body when ``FLAGS.nbody`` is set, first-order LPT
         evaluated at the last stage otherwise) and save them to ``fpath``.
      2. Add zero-mean, unit-variance Gaussian noise to the final field.
      3. Minimise ``chisq + prior`` over the linear field with
         ``tfp.optimizer.lbfgs_minimize``, starting from the noise itself.
      4. Evolve the minimiser forward once more and save the diagnostic
         figures through ``dg``.

    Args:
        _: Ignored (presumably the argv list handed over by an app runner).

    The function does not return: it terminates the process with status 0.
    """
    import sys  # Local import: needed only for the explicit exit at the end.

    dtype = tf.float32

    # Fixed seeds so the noise realisation (and hence the start point) repeats.
    tf.random.set_random_seed(100)
    np.random.seed(100)

    # Compute a few things first, using simple tensorflow
    a0 = FLAGS.a0
    a = FLAGS.af
    nsteps = FLAGS.nsteps
    bs, nc = FLAGS.box_size, FLAGS.nc
    # Parse the table once: column 0 becomes klin, column 1 becomes plin.
    planck = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin, plin = planck[0], planck[1]
    ipklin = iuspline(klin, plin)
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    def evolve(linear):
        """Run LPT init + N-body over ``stages`` and paint the final field."""
        # NOTE(review): the LPT start is hard-coded to a0=0.1 while ``stages``
        # begins at FLAGS.a0 -- confirm the two are meant to agree.
        state = lpt_init(linear, a0=0.1, order=1)
        final_state = nbody(state, stages, nc)
        return cic_paint(tf.zeros_like(linear), final_state[0])

    tf.reset_default_graph()
    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        # Best effort: any failure to load falls back to regenerating the data.
        print('Exception occured', e)
        tfic = linear_field(nc, bs, ipklin, batch_size=1, seed=100,
                            dtype=dtype)
        if FLAGS.nbody:
            tfinal_field = evolve(tfic)
        else:
            # NOTE(review): this branch produces LPT data, yet the
            # reconstruction below always uses the N-body model -- confirm.
            final_state = lpt_init(tfic, a0=stages[-1], order=1)
            tfinal_field = cic_paint(tf.zeros_like(tfic), final_state[0])
        with tf.Session() as sess:
            ic, fin = sess.run([tfic, tfinal_field])
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    ################################################################
    tf.reset_default_graph()
    print('ic constructed')

    noise = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)
    data = fin + noise

    # The optimiser works on a flat vector; ``flatten`` already returns a copy.
    start = noise.flatten()

    def recon_prototype(linear):
        """Return ``chisq + prior`` for a flattened candidate linear field.

        No smoothing of the residual is applied in this variant.
        """
        linear = tf.reshape(linear, data.shape)
        final_field = evolve(linear)

        # Data term: summed squared residual scaled by 1 / nc**3.
        residual = final_field - data.astype(np.float32)
        chisq = tf.reduce_sum(tf.multiply(residual, residual))
        chisq = tf.multiply(chisq, 1/nc**3, name='chisq')

        # Prior term: |r2c3d(linear)|**2 weighted by 1 / priorwt, same scaling.
        lineark = r2c3d(linear, norm=nc**3)
        priormesh = tf.square(tf.cast(tf.abs(lineark), tf.float32))
        prior = tf.reduce_sum(tf.multiply(priormesh, 1/priorwt))
        prior = tf.multiply(prior, 1/nc**3, name='prior')

        return chisq + prior

    @tf.function
    def min_lbfgs():
        return tfp.optimizer.lbfgs_minimize(
            make_val_and_grad_fn(recon_prototype),
            initial_position=tf.constant(start),
            tolerance=1e-5,
            max_iterations=50)

    with tf.Session() as sess:
        results = sess.run(min_lbfgs())
    print(results)
    minimum = results.position
    print(minimum)

    tf.reset_default_graph()
    print('\nminimized\n')

    # Feed the minimiser straight in as the initial field; there is no need to
    # draw a random linear field only to multiply it by zero.
    tfic = tf.constant(minimum.reshape(data.shape), dtype=dtype)
    tfinal_field = evolve(tfic)
    with tf.Session() as sess:
        minic, minfin = sess.run([tfic, tfinal_field])

    dg.saveimfig(0, [minic, minfin], [ic, fin], fpath)
    dg.save2ptfig(0, [minic, minfin], [ic, fin], fpath, bs)

    # ``sys.exit`` instead of the site-provided ``exit`` helper, which is meant
    # for interactive sessions and is absent when the site module is skipped.
    sys.exit(0)
Ejemplo n.º 15
0
def main():

    startw = time.time()

    # Run normal flowpm to generate data
    try:
        ic, fin = np.load(fpath + 'ic.npy'), np.load(fpath + 'final.npy')
        print('Data loaded')
    except Exception as e:
        print('Exception occured', e)
        tfic = linear_field(nc,
                            bs,
                            ipklin,
                            batch_size=1,
                            seed=100,
                            dtype=dtype)
        tfinal_field = pm(tfic)
        ic, fin = tfic.numpy(), tfinal_field.numpy()
        np.save(fpath + 'ic', ic)
        np.save(fpath + 'final', fin)

    print('\ndata constructed\n')

    noise = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(np.float32)
    data_noised = fin + noise
    data = data_noised

    @tf.function
    def recon_prototype(linear, Rsm=0):
        """Evaluate the reconstruction objective for a candidate linear field.

        The objective is ``chisq + prior``: the summed squared residual between
        ``pm(linear)`` and ``data``, plus the summed
        ``|r2c3d(linear)|**2 / priorwt``, each scaled by ``1 / nc**3``.

        Args:
            linear: Candidate field; reshaped to ``data.shape`` before use.
            Rsm: Smoothing scale, read only when ``anneal`` is truthy. It enters
                through the weights ``exp(-kmesh**2 * (Rsm * bs / nc)**2)``
                applied to the transformed residual.

        Returns:
            The scalar loss tensor.
        """
        field = tf.reshape(linear, data.shape)
        model = pm(field)

        # Mismatch between the forward model and the noisy observation.
        misfit = model - data.astype(np.float32)

        if anneal:
            print("\nAdd annealing section to graph\n")
            # Weight the residual between r2c3d and c2r3d (presumably a
            # forward/inverse FFT pair -- confirm against their definitions).
            radius = Rsm * bs / nc
            weights = tf.exp(
                tf.multiply(-kmesh**2, tf.multiply(radius, radius)))
            misfit_k = tf.multiply(r2c3d(misfit, norm=nc**3),
                                   tf.cast(weights, tf.complex64))
            misfit = c2r3d(misfit_k, norm=nc**3)

        # Data term.
        chisq = tf.multiply(tf.reduce_sum(tf.multiply(misfit, misfit)),
                            1 / nc**3,
                            name='chisq')

        # Prior term.
        field_k = r2c3d(field, norm=nc**3)
        power = tf.square(tf.cast(tf.abs(field_k), tf.float32))
        prior = tf.multiply(tf.reduce_sum(tf.multiply(power, 1 / priorwt)),
                            1 / nc**3,
                            name='prior')

        return chisq + prior
#

    @tf.function
    def val_and_grad(x, RR):
        """Return ``recon_prototype(x, RR)`` and its gradient with respect to ``x``."""
        with tf.GradientTape() as recorder:
            # Watch x explicitly so the tape records operations on it.
            recorder.watch(x)
            objective = recon_prototype(x, RR)
        return objective, recorder.gradient(objective, x)

    @tf.function
    def min_lbfgs(x0, RR):
        """Run L-BFGS from ``x0`` on the objective evaluated with smoothing ``RR``.

        Returns whatever ``tfp.optimizer.lbfgs_minimize`` returns.
        """

        def objective(x):
            # RR is handed to the objective as a float32 constant.
            return val_and_grad(x, tf.constant(RR, dtype=tf.float32))

        return tfp.optimizer.lbfgs_minimize(objective,
                                            initial_position=x0,
                                            tolerance=1e-10,
                                            max_iterations=200)


#
#    def make_val_and_grad_fn(value_fn, R):
#        @functools.wraps(value_fn)
#        def val_and_grad(x):
#            return tfp.math.value_and_gradient(value_fn, x)
#        return val_and_grad
#
#    @tf.function
#    def min_lbfgs(x0, RR):
#        return tfp.optimizer.lbfgs_minimize(
#            make_val_and_grad_fn(recon_prototype),
#            recon_prototype,
#            initial_position=x0,
#            tolerance=1e-10,
#            max_iterations=100)
#
##Reconstruction

    x0 = np.random.normal(0, 1, nc**3).reshape(fin.shape).astype(
        np.float32).flatten()

    RRs = [2, 1, 0.5, 0]
    for iR, RR in enumerate(RRs):

        results = min_lbfgs(x0, RR)
        #results = sopt.minimize(fun=func, x0=x0, args = RR, jac=True, method='L-BFGS-B',
        #                        options={'maxiter':200, 'ftol': 2.220446049250313e-09, 'gtol': 1e-10,})

        print(results)
        ###
        minic = results.position.numpy().reshape(data.shape)
        print(minic.shape)
        print('\nminimized\n')
        minfin = pm(tf.constant(minic, dtype=tf.float32)).numpy()
        dg.saveimfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '')
        dg.save2ptfig("-R%d" % RR, [minic, minfin], [ic, fin], fpath + '', bs)