def nbody_fn(mesh,
             klin,
             plin,
             nc=FLAGS.nc,
             bs=FLAGS.box_size,
             batch_size=FLAGS.batch_size,
             a0=FLAGS.a0,
             a=FLAGS.af,
             nsteps=FLAGS.nsteps,
             dtype=tf.float32):
    """Pyramid N-body function.

    Builds a distributed linear initial field with `mtfpm.linear_field`,
    splits it into halo-padded high-resolution blocks plus a downsampled
    low-resolution grid, initialises particles with `mtfpm.lpt_init` at scale
    factor `a0`, evolves them with `mtfpm.nbody` over `nsteps` stages from
    `a0` to `a`, and CIC-paints the particles back onto the full mesh.

    Args:
        mesh: Mesh TensorFlow mesh on which all tensors are created.
        klin: Not used by this function (presumably kept for signature
            symmetry with `plin` -- TODO confirm).
        plin: Array of linear power spectrum values, imported as float32.
        nc: Number of cells per side of the full-resolution mesh.
        bs: Box size, forwarded to `mtfpm.linear_field`.
        batch_size: Size of the batch dimension.
        a0: Initial scale factor; start of the integration and LPT epoch.
        a: Final scale factor.
        nsteps: Number of integration stages between `a0` and `a`.
        dtype: Not used; all tensors here are built as float32.

    Returns:
        Tuple `(initc, final_field)` of mesh tensors with dimensions
        `[batch, nx, ny, nz]`: the linear initial field and the painted
        final field.
    """
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition
    downsampling_factor = FLAGS.dsample
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(
        mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
    ky = mtf.import_tf_tensor(
        mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
    kz = mtf.import_tf_tensor(
        mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid (rescaled to full-resolution grid units)
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block side + halo on both ends)
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size,
        nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size,
    ], symmetric=False)

    kx_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Compute initial initial conditions distributed
    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    # Reshaping array into high resolution mesh: each local slice gains three
    # singleton block axes so it matches hr_shape.
    field = mtf.slicewise(
        lambda x: tf.expand_dims(
            tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
        [initc],
        output_dtype=tf.float32,
        output_shape=hr_shape,
        name='my_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng custom reshape because mesh is pretty dumb
    low = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [low],
        output_dtype=tf.float32,
        output_shape=lr_shape,
        name='my_dumb_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # The LPT epoch must match stages[0] (== a0): mtfpm.nbody starts
    # integrating the state from the first stage.  This was previously
    # hard-coded to 0.1.
    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    final_state = mtfpm.nbody(
        state,
        stages,
        lr_shape,
        hr_shape,
        kv_lr,
        kv_hr,
        halo_size,
        downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the graph for reconstructing a linear field from `data`.

    A mesh variable named 'linear' (dimensions `hr_shape`) holds the
    candidate initial field.  It is forward-modelled with LPT (or LPT +
    `mtfpm.nbody` when `FLAGS.nbody` is set) and CIC-painted.  The loss is a
    Gaussian-smoothed chi-square against `data` plus a power-spectrum prior
    on the variable.  The gradient of the loss is high-pass filtered, and an
    SGD-style update op is returned.

    Args:
        mesh: Mesh TensorFlow mesh on which all tensors are created.
        data: Observed field, imported with dimensions `[batch, nx, ny, nz]`.
            Its `.dtype` is also used for the `input_field` placeholder.
        nc: Number of cells per side of the full-resolution mesh.
        bs: Box size, used to scale wavenumbers and the prior.
        batch_size: Size of the batch dimension.
        a0: Initial scale factor (start of the N-body stages).
        a: Final scale factor.
        nsteps: Number of stages between `a0` and `a`.
        dtype: `tf.float32` or `tf.float64`.

    Returns:
        Tuple `(initc, final_field, loss, var_grads, update_op, linearop,
        input_field, lr, R0)`:
          - `initc`: the variable reshaped to `[batch, nx, ny, nz]`;
          - `final_field`: the painted forward model;
          - `loss`: chi-square + prior;
          - `var_grads`: one-element list with the high-passed gradient;
          - `update_op`: assigns `fieldvar - var_grads[0] * lr`;
          - `linearop`: assigns `input_field` to the variable;
          - `input_field`: placeholder with the blocked `hr_shape` layout;
          - `lr`: float32 scalar placeholder (step size);
          - `R0`: float32 scalar placeholder (smoothing scale).

    Raises:
        ValueError: If `dtype` is neither `tf.float32` nor `tf.float64`.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later with a NameError on
        # `npdtype`; fail early with a clear message instead.
        raise ValueError(
            "recon_prototype supports tf.float32 or tf.float64, got %s" %
            (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Cap the halo at half of the smallest block side.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Column 1 of the table is the power spectrum used by the prior.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # kvec for low resolution grid (rescaled to full-resolution grid units)
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[0].squeeze().astype(npdtype) / 2**downsampling_factor,
        shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[1].squeeze().astype(npdtype) / 2**downsampling_factor,
        shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[2].squeeze().astype(npdtype) / 2**downsampling_factor,
        shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block side + halo on both ends)
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size,
        nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size,
    ], symmetric=False, dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[0].squeeze().astype(npdtype), shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[1].squeeze().astype(npdtype), shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[2].squeeze().astype(npdtype), shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks (unpadded block sides)
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)
    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(
        mesh, kvec_pr[0].squeeze().astype(npdtype), shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(
        mesh, kvec_pr[1].squeeze().astype(npdtype), shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(
        mesh, kvec_pr[2].squeeze().astype(npdtype), shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    ## Compute initial initial conditions distributed
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    # Op that loads a user-provided field into the variable.
    linearop = mtf.assign(fieldvar, mtfinp)

    field = fieldvar
    initc = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng custom reshape because mesh is pretty dumb
    low = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [low],
        output_dtype=initc.dtype,
        output_shape=lr_shape,
        name='my_dumb_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        # The LPT epoch must match stages[0] (== a0) because mtfpm.nbody
        # starts integrating from the first stage (was hard-coded to 0.1).
        state = mtfpm.lpt_init(
            low,
            high,
            a0,
            kv_lr,
            kv_hr,
            halo_size,
            hr_shape,
            lr_shape,
            part_shape[1:],
            downsampling_factor=downsampling_factor,
            antialias=True)

        final_state = mtfpm.nbody(
            state,
            stages,
            lr_shape,
            hr_shape,
            kv_lr,
            kv_hr,
            halo_size,
            downsampling_factor=downsampling_factor)
    else:
        # LPT only, evaluated directly at the final stage.
        final_state = mtfpm.lpt_init(
            low,
            high,
            stages[-1],
            kv_lr,
            kv_hr,
            halo_size,
            hr_shape,
            lr_shape,
            part_shape[1:],
            downsampling_factor=downsampling_factor,
            antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(data), shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |field(k)| / sqrt(P(|k|)), with P interpolated from the table.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # NOTE(review): assumes the tabulated P(k) is sampled on a grid that
        # is regular in log(k) between 1e-5 and 1e3 -- confirm against the
        # data file.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(
        _cwise_prior, [cfield, pk] + kv_pr, output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian smoothing of the residual with scale R0.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        # Keeps 1 - Gaussian(R0 + 1/nyq): suppresses the large-scale modes.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(
            tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return (initc, final_field, loss, var_grads, update_op, linearop,
            input_field, lr, R0)
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Compute the LPT displacement of a given field and paint it.

    The dense `initial_conditions` tensor is split into halo-padded
    high-resolution blocks plus a downsampled low-resolution grid.
    Particles are displaced with `mtfpm.lpt_init` at scale factor `a0` and
    CIC-painted back onto the full mesh.  No N-body integration is run.

    Args:
        mesh: Mesh TensorFlow mesh on which all tensors are created.
        initial_conditions: Dense TF tensor whose element count must allow
            the reshape to `[batch_size, nbx, nc//nbx, nby, nc//nby, 1, nc]`
            (presumably `[batch_size, nc, nc, nc]` -- TODO confirm).
        derivs: Unused; kept for interface compatibility.
        nc: Number of cells per side of the full-resolution mesh.
        bs: Unused; kept for interface compatibility.
        batch_size: Size of the batch dimension.
        a0: Scale factor at which the LPT displacement is computed.
        a, nsteps: Unused; kept for interface compatibility.

    Returns:
        Mesh tensor with dimensions `[batch, nx, ny, nz]` holding the painted
        LPT field.
    """
    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition
    downsampling_factor = FLAGS.dsample
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # kvec for low resolution grid (rescaled to full-resolution grid units)
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block side + halo on both ends)
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size,
        nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size,
    ], symmetric=False)
    kx_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
    # Same [y, z, x] ordering as kv_lr and as the other pyramid builders in
    # this file (this was previously [x, y, z]).
    kv_hr = [ky_hr, kz_hr, kx_hr]

    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Split the dense cube into blocks laid out as hr_shape.  The leading
    # axis must be batch_size to match batch_dim (was hard-coded to 1).
    initc = tf.reshape(initial_conditions, [
        batch_size, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y, 1,
        nc
    ])
    initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6])
    field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape)

    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng custom reshape because mesh is pretty dumb
    low = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [low],
        output_dtype=tf.float32,
        output_shape=lr_shape,
        name='my_dumb_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Same lpt_init call signature as the other builders in this file; the
    # previous extra positional `k_dims` shifted `part_shape` and collided
    # with the `antialias` keyword.
    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    # Only the LPT displacement is painted; no N-body step is run here.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack usisng custom reshape because mesh is pretty dumb
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return final_field
def nbody_model(mesh):
    """Build a pyramid particle-mesh N-body simulation graph on `mesh`.

    Generates `FLAGS.batch_size` linear initial conditions of
    `FLAGS.cube_size`**3 cells with `flowpm.linear_field`.  The field is
    split into 8 x 4 x 1 blocks with 16-cell halos plus a low-resolution
    grid downsampled by 2**2.  Particles are initialised with LPT at
    `FLAGS.a0` and evolved with `mtfpm.nbody` from `FLAGS.a0` to a = 1.0
    over `FLAGS.pm_steps` stages.  The result is CIC-painted onto the mesh.

    Args:
        mesh: Mesh TensorFlow mesh on which all tensors are created.

    Returns:
        Mesh tensor with dimensions `[batch, nx, ny, nz]` holding the painted
        final field.
    """
    # Setup parameters
    bs = FLAGS.box_size
    nc = FLAGS.cube_size
    batch_size = FLAGS.batch_size
    a0 = FLAGS.a0
    a = 1.0
    nsteps = FLAGS.pm_steps

    # Compute a few things first, using simple tensorflow.
    # Read the power spectrum table once: column 0 is k, column 1 is P(k).
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin, plin = pk_table[0], pk_table[1]
    ipklin = iuspline(klin, plin)

    # Initialize the integration steps
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Generate a batch of 3D initial conditions
    initial_conditions = flowpm.linear_field(
        nc,  # size of the cube
        bs,  # Physical size of the cube
        ipklin,  # Initial power spectrum
        batch_size=batch_size)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = 8
    n_block_y = 4
    n_block_z = 1
    halo_size = 16

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    x_dim_t = mtf.Dimension("x_lr_t", lnc)
    y_dim_t = mtf.Dimension("y_lr_t", lnc)
    z_dim_t = mtf.Dimension("z_lr_t", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # kvec for low resolution grid (rescaled to full-resolution grid units)
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[x_dim_t])
    ky_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[y_dim_t])
    kz_lr = mtf.import_tf_tensor(
        mesh,
        kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
        shape=[z_dim_t])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block side + halo on both ends)
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size,
        nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size,
    ], symmetric=False)
    kx_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(
        mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Split the dense cube into blocks laid out as hr_shape.
    initc = tf.reshape(initial_conditions, [
        batch_size, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y, 1,
        nc
    ])
    initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6])
    field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape)

    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng custom reshape because mesh is pretty dumb
    low = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [low],
        output_dtype=tf.float32,
        output_shape=lr_shape,
        name='my_dumb_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # The LPT epoch must match stages[0] (== a0) because mtfpm.nbody starts
    # integrating from the first stage (was hard-coded to 0.1).
    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    # Here we can run our nbody
    final_state = mtfpm.nbody(
        state,
        stages,
        lr_shape,
        hr_shape,
        kv_lr,
        kv_hr,
        halo_size,
        downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack usisng custom reshape because mesh is pretty dumb
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return final_field