def get_indices(self, keys: mtf.Tensor, query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
    """Score the query against both key halves and pick the best key pairs.

    The query is contracted with ``keys``, the ``self.knn`` best entries are
    kept for each of the two halves, every (first, second) combination is
    scored as the sum of its two half scores, and the ``self.knn`` best
    combinations are returned.

    Args:
        keys: key tensor; its third dimension is used as the per-half key
            dimension of the score tensor.
        query: query tensor; all of its dimensions except the last are kept
            in the score tensor.

    Returns:
        A ``(scores, indices)`` pair laid out along the ``"knn"`` dimension.
        ``indices`` holds the combined index
        ``first_index * self.n_keys + second_index`` of each selected pair.
    """
    half_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
    # [b, l, h, 2, n_keys]
    half_scores = mtf.einsum([query, keys], output_shape=half_shape)

    top_dim = mtf.Dimension("knn", self.knn)
    # [b, l, h, 2, knn]
    half_scores, half_indices = mtf.top_k(
        half_scores, half_shape.dims[-1], top_dim)

    # Build every (first half, second half) combination of the survivors.
    pairs_dim = mtf.Dimension("knn_square_dim", self.knn**2)

    def flatten_pairs(tensor, pair_dims):
        # Collapse the trailing ("knn", "knn2") dimensions into one.
        return mtf.replace_dimensions(tensor, pair_dims, pairs_dim)

    first_scores, second_scores = mtf.unstack(
        half_scores, half_scores.shape.dims[-2])
    second_scores = mtf.rename_dimension(second_scores, "knn", "knn2")
    pair_shape = mtf.Shape(
        first_scores.shape.dims + second_scores.shape.dims[-1:])
    pair_scores = flatten_pairs(
        mtf.add(first_scores, second_scores, output_shape=pair_shape),
        pair_shape[-2:])

    # Combined index of a pair: first_index * n_keys + second_index.
    first_indices, second_indices = mtf.unstack(
        half_indices, half_indices.shape.dims[-2])
    first_indices = mtf.multiply(first_indices, self.n_keys)
    second_indices = mtf.rename_dimension(second_indices, "knn", "knn2")
    pair_indices = flatten_pairs(
        mtf.add(first_indices, second_indices, output_shape=pair_shape),
        pair_shape[-2:])

    best_scores, best_positions = mtf.top_k(
        pair_scores, pair_scores.shape.dims[-1], top_dim)
    best_indices = mtf.gather(pair_indices, best_positions, pairs_dim)
    return best_scores, best_indices
def testWhileLoopOperation(self):
    """Check the splittable dims reported by a simple while loop.

    The loop under test is the graph equivalent of::

        for i in range(10):
            x = x * 2
    """

    def keep_looping(i, x):
        # Loop condition: run while the counter is below 10.
        return mtf.less(i, 10)

    def step(i, x):
        # Loop body: advance the counter and double x.
        return [mtf.add(i, 1), mtf.multiply(x, 2)]

    counter = mtf.constant(self.mesh, 0, mtf.Shape([]))
    loop_op = mtf.WhileLoopOperation(keep_looping, step, [counter, self.x])

    self.assertEqual(loop_op.splittable_dims, frozenset({"a", "b"}))
    self.assertEqual(loop_op.unsplittable_dims, frozenset())
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the Mesh TensorFlow graph of the reconstruction forward model.

    The graph starts from a trainable field variable named ``'linear'``,
    evolves it with ``mtfpm.lpt_init_single`` and ``mtfpm.nbody_single``,
    paints the result on a mesh, applies two small networks (weights from
    ``setupfnn()``) slice-wise, shifts the painted model along the last axis,
    and finally assembles a prior term and a smoothed residual term.

    Args:
        mesh: ``mtf.Mesh`` on which all tensors are created.
        datasm: data field; converted with ``tf.convert_to_tensor`` and
            imported with shape ``[batch, nx, ny, nz]``.
        rsdfactor: factor multiplied into the read-out velocity ``vz``.
        M0: constant added to model and data inside the logarithms of the
            residual.
        R0: scale of the Gaussian weights applied to the residual in Fourier
            space.
        width: factor multiplying the last network layer before the sigmoid.
        off: optional offset added to the residual, or ``None``.
        istd: optional field multiplied into the residual, or ``None``; when
            ``None`` the residual is divided by 0.25 instead.
        x0: optional constant initial value for the ``'linear'`` variable;
            when ``None`` the variable is initialised from a unit normal.
        nc: number of mesh cells per side.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: first value of the ``stages`` array passed to the N-body solver.
        a: last value of the ``stages`` array.
        nsteps: number of entries in ``stages``.
        dtype: ``tf.float32`` or ``tf.float64``; selects the complex dtype of
            the Fourier transforms and the numpy dtype of the imported power
            spectrum.

    Returns:
        ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field, model]``, ``metrics`` is
        ``[chisq, prior, loss]`` and ``kv`` is the list of imported
        wave-vector tensors.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Fail early with a clear message instead of an unbound-local error.
        raise ValueError(
            "recon_model only supports tf.float32 and tf.float64, got %r" %
            (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Smoothing scales used by the cwise_fingauss filters below.
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block extent.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Second column of the tabulated spectrum; the path is relative to the
    # current working directory. The file is read only once.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # NOTE(review): the [ky, kz, kx] order presumably matches the axis layout
    # produced by mesh_utils.r2c3d -- confirm against that helper.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation
    if x0 is None:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.random_normal_initializer(mean=0.0,
                                                     stddev=1,
                                                     seed=None))
    else:
        fieldvar = mtf.get_variable(
            mesh,
            'linear',
            part_shape,
            initializer=tf.constant_initializer(x0))

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)

    x = final_field

    # Network parameters: weights, biases and normalisation constants.
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc = float_to_mtf(nc * 1., mesh, scalar)
    tfbs = float_to_mtf(bs, mesh, scalar)

    # Filter the painted field with cwise_decic, subtract 1, then build two
    # cwise_fingauss-filtered copies at scales R1 and R2.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        # Convolve each input with `kernel`, normalise the stacked features
        # and run them through a 3-layer ReLU network ending in a sigmoid.
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel,
                         [1, 1, 1, 1, 1], 'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts, [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        # Stack the three inputs as features, normalise them and run them
        # through a 3-layer ELU network; the output is rescaled by msy/mmy.
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts, [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    # RSD below
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    # Broadcast the index grid over the batch dimension.
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    # Small floor so the division by masssm below cannot hit exact zeros.
    masssm = masssm + 1e-5
    # NOTE(review): imasssm is never used afterwards (and is built from x,
    # not masssm). Kept so the graph is unchanged -- confirm and remove.
    imasssm = mtf.pow(x, -1.)

    # Last component of the second state entry, used as paint weights.
    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    # NOTE(review): this zero field is immediately overwritten and the paint
    # below targets final_field instead of momz -- confirm whether
    # cic_paint_fr accumulates onto its first argument.
    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(final_field,
                              final_state,
                              output_shape=part_shape,
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh,
                              weights=vzweights)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    # Shift
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X],
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    # NOTE(review): the shift below is driven by vzweights, not by the
    # rsdfactor-scaled vz computed just above, so rsdfactor only affects the
    # printed tensor. Kept as in the original -- confirm the intended input.
    Xrsd = mtf.slicewise(lambda x, vz: x + tf.stack(
        [tf.zeros_like(vz), tf.zeros_like(vz), vz], 4), [X, vzweights],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)

    # Read the model at the unshifted positions and repaint it at the
    # shifted ones.
    modelread = mcomp.cic_readout_fr(model, [X],
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd],
                                  output_shape=part_shape,
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh,
                                  weights=modelread)
    model = modelrsd
    print(modelrsd)

    # Likelihood and prior here
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # Divide |kfield| by the square root of the spectrum interpolated at
        # each mode's |k|.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # NOTE(review): assumes the tabulated spectrum is log-regular in k
        # between 1e-05 and 1000.0 -- confirm against the data file.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Anneal
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)

    mtfoff = None
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        if mtfoff is not None:
            # The offset is applied a second time here. The original author
            # flagged this as "doing things wrong"; it is kept so existing
            # results do not change.
            diff = (diff + mtfoff) * mtfistd
        else:
            # Without an offset there is nothing to re-apply; previously this
            # path raised NameError on the undefined mtfoff.
            diff = diff * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply each mode by the Gaussian weight exp(-k^2 (R0 bs / nc)^2).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv