return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
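    # Run one evaluation step on every replica via strategy.run; returns per-replica outputs.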
    return strategy.run(test_step, args=(dataset_inputs,))


###
# Training


losses = []


adam = myAdam(params['rim_iter'])
adam10 = myAdam(10*params['rim_iter'])

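# Build the model output folder name from the run configuration (box size, mesh, time stepping / LPT order).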
if args.parallel: suffpath = '_halo_parallel_single' + args.suffix
else: suffpath = '_halo_split_single' + args.suffix
if args.nbody: ofolder = './models/L%04d_N%03d_T%02d%s/'%(bs, nc, nsteps, suffpath)
else: ofolder = './models/L%04d_N%03d_LPT%d%s/'%(bs, nc, args.lpt_order, suffpath)
try: os.makedirs(ofolder)
except Exception as e: print(e)


for x in test_dist_dataset:
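    # Run one distributed test step and inspect the per-replica output shapes.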
    print('Testing')
    print(len(x), x[0].values[0].shape)
    a, b, c, d = distributed_test_step(x)
    print(a.values[0].shape, b.values[0].shape, c.values[0].shape, d.values[0].shape)

# Example #2

##params = {}
##params['input_size'] = args.input_size
##params['cell_size'] = args.cell_size
##params['strides'] = 2
##params['middle_size'] = args.input_size // params['strides']  #lets divide by strides
##params['cell_kernel_size'] = 5
##params['input_kernel_size'] = 5
##params['middle_kernel_size'] = 5
##params['output_kernel_size'] = 5
##params['rim_iter'] = args.rim_iter
##params['input_activation'] = 'tanh'
##params['output_activation'] = 'linear'
##params['nc'] = nc
##

adam = myAdam(10)
adam10 = myAdam(100)
#
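# Build the CNN model: a simple U-Net with kernel size 5 (args.cell_size is assumed to set the channel width).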
cnn = SimpleUNet(args.cell_size, kernel_size=5)
#optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr)

step = tf.Variable(0, trainable=False)
boundaries = [100, 1000]
values = [args.lr, args.lr / 2., args.lr / 5.]
learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, values)
learning_rate = learning_rate_fn(step)
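# Note: the ExponentialDecay schedule below overrides the piecewise schedule defined above.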
learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
    args.lr, decay_steps=1000, decay_rate=0.9, staircase=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
params['input_size'] = args.input_size
params['cell_size'] = args.cell_size
params['strides'] = 2
params['middle_size'] = args.input_size // params['strides']  # let's divide by strides
params['cell_kernel_size'] = 5
params['input_kernel_size'] = 5
params['middle_kernel_size'] = 5
params['output_kernel_size'] = 5
params['rim_iter'] = args.rim_iter
params['input_activation'] = 'tanh'
params['output_activation'] = 'linear'
params['nc'] = nc


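# Enforce a minimum number of optimizer iterations, scaling with the RIM iteration count.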
adamiters, adamiters10 = max(10, params['rim_iter']), max(100, params['rim_iter']*10)
adam = myAdam(adamiters)
adam10 = myAdam(adamiters10)
fid_recon = Recon_DM(nc, bs, a0=a0, af=af, nsteps=nsteps, nbody=args.nbody, lpt_order=args.lpt_order, anneal=True)

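# Output folder name encodes box size, mesh resolution, and time stepping (or LPT order).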
suffpath = '_split' + args.suffix
if args.nbody: ofolder = './models/L%04d_N%03d_T%02d%s/'%(bs, nc, nsteps, suffpath)
else: ofolder = './models/L%04d_N%03d_LPT%d%s/'%(bs, nc, args.lpt_order, suffpath)
try: os.makedirs(ofolder)
except Exception as e: print(e)




def get_data(nsims=args.nsims):
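    # Data paths (commented below) point to pre-generated simulation directories.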
    #if args.nbody: dpath = '/project/projectdirs/m3058/chmodi/rim-data/L%04d_N%03d_T%02d/'%(bs, nc, nsteps)
    #else: dpath = '/project/projectdirs/m3058/chmodi/rim-data/L%04d_N%03d_LPT%d/'%(bs, nc, args.lpt_order)