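# Imports assumed by this excerpt (not shown in the original snippet). Project-specific
# helpers (build_rim_parallel_single, build_rim_split_single, myAdam, recon_grad,
# setupbias, get_data, pm, gal_sample, check_im, check_2pt, Recon_Poisson, datamodel,
# args, nc, bs, plambda, nsteps, strategy, BATCH_SIZE_PER_REPLICA, BUFFER_SIZE)
# are expected to be defined elsewhere in the script/repo.
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
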
params = {}
params['input_size'] = args.input_size
params['cell_size'] = args.cell_size
params['strides'] = 2
params['middle_size'] = args.input_size // params['strides']  # middle-layer size: input downsampled by the stride factor
params['cell_kernel_size'] = 5
params['input_kernel_size'] = 5
params['middle_kernel_size'] = 5
params['output_kernel_size'] = 5
params['rim_iter'] = args.rim_iter
params['input_activation'] = 'tanh'
params['output_activation'] = 'linear'
params['nc'] = nc
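# 'epoch' is read below (EPOCHS = params['epoch']) but is never set in this excerpt;
# assuming the value comes from a command-line flag (hypothetical name args.epochs):
params['epoch'] = getattr(args, 'epochs', 1)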


rim = build_rim_parallel_single(params)
grad_fn = datamodel.recon_grad
# Adam baselines: the same gradient-based reconstruction run for 1x and 10x the RIM's
# iteration budget, used for comparison in the diagnostics below
adam = myAdam(params['rim_iter'])
adam10 = myAdam(10*params['rim_iter'])
optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
fid_recon = Recon_Poisson(nc, bs, plambda=plambda, a0=a0, af=af, nsteps=nsteps, nbody=args.nbody, lpt_order=args.lpt_order, anneal=True)  # used below for the fiducial reconstruction comparison

#################################################
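# Output folder name encodes the box size (L), mesh size (N), dynamics (N-body steps
# or LPT order) and the Poisson lambda (via the suffix).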
suffpath = '_p%03d_single'%(100*plambda) + args.suffix
if args.nbody: ofolder = './models/poisson_L%04d_N%03d_T%02d%s/'%(bs, nc, nsteps, suffpath)
else: ofolder = './models/poisson_L%04d_N%03d_LPT%d%s/'%(bs, nc, args.lpt_order, suffpath)
try: os.makedirs(ofolder)
except Exception as e: print(e)


GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print("GLOBAL_BATCH_SIZE : ", GLOBAL_BATCH_SIZE)
EPOCHS = params['epoch']
# column 0: true initial field, column 1: corresponding observation
train_dataset = tf.data.Dataset.from_tensor_slices((traindata[:, 0], traindata[:, 1])).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((testdata[:, 0], testdata[:, 1])).batch(strategy.num_replicas_in_sync)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
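# The distributed datasets split each global batch across the replicas in sync.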

# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")


with strategy.scope():
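    # Create the model, bias parameters and optimizer inside strategy.scope() so their
    # variables are mirrored across replicas.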
    if args.parallel: rim = build_rim_parallel_single(params)
    else: rim = build_rim_split_single(params)
    grad_fn = recon_grad
    b1, b2, errormesh = setupbias()
    bias = tf.constant([b1, b2], dtype=tf.float32)
    print(bias)
    grad_params = [bias, errormesh]

    def get_opt(lr):
        return  tf.keras.optimizers.Adam(learning_rate=lr)
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)
    checkpoint = tf.train.Checkpoint(model=rim)
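    # the checkpoint tracks only the RIM weights; optimizer state is not included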
    #


def main():
    """
    Model function for the CosmicRIM.
    """

    rim = build_rim_parallel_single(params)
    grad_fn = recon_grad
    #

#
#    train_dataset = tf.data.Dataset.range(args.batch_in_epoch)
#    train_dataset = train_dataset.map(pm_data)
#    # dset = dset.apply(tf.data.experimental.unbatch())
#    train_dataset = train_dataset.prefetch(-1)
#    test_dataset = tf.data.Dataset.range(1).map(pm_data_test).prefetch(-1)
#
    traindata, testdata = get_data()
    idx = np.random.randint(0, traindata.shape[0], args.sims_in_loop)
    xx, yy = traindata[idx, 0].astype(np.float32), traindata[idx, -1].astype(np.float32)
    x_init = np.random.normal(size=xx.size).reshape(xx.shape).astype(np.float32)
    x_pred = rim(tf.constant(x_init), tf.constant(yy), grad_fn, tf.constant(xx))
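    # warm-up forward pass so the model's variables are created before
    # rim.trainable_variables is used in rim_train below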

    

    #
    # @tf.function
    def rim_train(x_true, x_init, y):
        """Accumulate gradients over sub-batches of `args.sims_in_loop` simulations."""
        gradients = [0.]*len(rim.trainable_variables)
        n = args.sims_in_loop
        nsub = args.batch_size//n
        for i in range(nsub):
            with tf.GradientTape() as tape:
                a, b, c = x_init[n*i:n*i+n], y[n*i:n*i+n], x_true[n*i:n*i+n]
                loss = rim(tf.constant(a), tf.constant(b), grad_fn, tf.constant(c))[1]
            grads = tape.gradient(loss, rim.trainable_variables)
            # average the gradients over the sub-batches; the returned loss is that of the last sub-batch
            for j in range(len(grads)):
                gradients[j] = gradients[j] + grads[j] / nsub
        return loss, gradients



    ## Train and save
    piter, testiter = 10, 50   # print every piter iterations, run diagnostics every testiter
    losses = []
    #lrs = [0.001, 0.0005, 0.0001]
    #liters = [101, 501, 2001]
    lrs = [0.0005, 0.0001]     # learning-rate schedule: liters[il] iterations at each rate
    liters = [1001, 2001]
    trainiter = 0
    start = time.time()
    x_test, y_test = None, None

    for il in range(len(lrs)):
        print('Learning rate = %0.3e'%lrs[il])
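        # a fresh Adam optimizer for each learning-rate stage (moment estimates are reset)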
        opt = tf.keras.optimizers.Adam(learning_rate=lrs[il])

        for i in range(liters[il]):
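            # sample a random batch of (true initial field, observation) pairs and a
            # Gaussian starting guess for the RIM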
            idx = np.random.randint(0, traindata.shape[0], args.batch_size)
            xx, yy = traindata[idx, 0].astype(np.float32), traindata[idx, -1].astype(np.float32)
            x_init = np.random.normal(size=xx.size).reshape(xx.shape).astype(np.float32)
            #x_init = (yy - (yy.max() - yy.min())/2.)/yy.std() + np.random.normal(size=xx.size).reshape(xx.shape).astype(np.float32)
            

            loss, gradients = rim_train(x_true=tf.constant(xx), 
                                    x_init=tf.constant(x_init), 
                                    y=tf.constant(yy))

            losses.append(loss.numpy())    
            opt.apply_gradients(zip(gradients, rim.trainable_variables))

            if i%piter == 0: 
                print("Time taken for %d iterations : "%piter, time.time() - start)
                print("Loss at iteration %d : "%i, losses[-1])
                start = time.time()
            if i%testiter == 0: 
                plt.plot(losses)
                plt.savefig(ofolder + 'losses.png')
                plt.close()

                ## check the 2pt functions and compare to Adam
                if x_test is None:
                    idx = np.random.randint(0, testdata.shape[0], 1)
                    x_test, y_test = testdata[idx, 0].astype(np.float32), testdata[idx, -1].astype(np.float32)
                    # draw an initial guess with the test-set shape before running the Adam baselines
                    x_init = np.random.normal(size=x_test.size).reshape(x_test.shape).astype(np.float32)
                    pred_adam = adam(tf.constant(x_init), tf.constant(y_test), grad_fn)
                    pred_adam = [pred_adam[0].numpy(), pm(pred_adam)[0].numpy()]
                    pred_adam10 = adam10(tf.constant(x_init), tf.constant(y_test), grad_fn)
                    pred_adam10 = [pred_adam10[0].numpy(), pm(pred_adam10)[0].numpy()]
                    minic, minfin = fid_recon.reconstruct(tf.constant(y_test), RRs=[1.0, 0.0], niter=args.rim_iter*10, lr=0.1)
                    compares = [pred_adam, pred_adam10, [minic[0], minfin[0]]]
                    print('Test set generated')

                x_init = np.random.normal(size=x_test.size).reshape(x_test.shape).astype(np.float32)
                #x_init = (y_test - (y_test.max() - y_test.min())/2.)/y_test.std() + np.random.normal(size=x_test.size).reshape(x_test.shape).astype(np.float32)
                pred, _ = rim(tf.constant(x_init), tf.constant(y_test), grad_fn, tf.constant(x_test))
                check_im(x_test[0], x_init[0], pred.numpy()[0], fname=ofolder + 'rim-im-%04d.png'%trainiter)
                check_im(y_test[0], x_init[0], gal_sample(pm(pred)).numpy()[0], fname=ofolder + 'rim-fin-%04d.png'%trainiter)
                check_2pt(x_test, y_test, rim, grad_fn, compares, fname=ofolder + 'rim-2pt-%04d.png'%trainiter)

                rim.save_weights(ofolder + '/%d'%trainiter)

            trainiter += 1
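

# entry point (assumed; the original excerpt defines main() but never calls it)
if __name__ == "__main__":
    main()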