def test_call_shape(self):
    """Check that ProductKeyValueMemory.call keeps the shape of its input."""
    # Layer hyper-parameters.
    key_size = 5
    n_keys = 6
    n_heads = 2
    knn = 3

    # Input layout: [batch=5, length=4, model=10].
    input_dims = [
        mtf.Dimension("batch", 5),
        mtf.Dimension("length", 4),
        mtf.Dimension("model", 10),
    ]

    def _counting_init(shape, dtype):
        # Deterministic variable contents: 1, 2, 3, ... reshaped to `shape`.
        flat = tf.range(np.prod(shape), dtype=dtype)
        return tf.reshape(1 + flat, shape)

    self.initializer_mock.side_effect = _counting_init

    layer = memory_layers.ProductKeyValueMemory(key_size, n_keys, n_heads,
                                                knn)
    inputs = mtf.ones(self.mesh, mtf.Shape(input_dims))

    # Minimal stand-in for the transformer context used by `call`.
    ctx = mock.MagicMock()
    ctx.mesh = self.mesh
    ctx.variable_dtype = tf.float32

    outputs = layer.call(ctx, inputs)

    # Dimensions should be untouched
    self.assertEqual(inputs.shape, outputs.shape)
def test_model():
    """Smoke test: build the gpt2 model graph and lower its logits to TF.

    Uses the module-level ``params`` dict for the model configuration and
    expects no exception from model construction or lowering.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    n_ctx = params["n_ctx"]
    token_dims = (mtf.Dimension("batch", 1), mtf.Dimension("sequence", n_ctx))

    # Inputs and labels are both all-ones int32 token tensors.
    features = {}
    for key in ('inputs', 'labels'):
        features[key] = mtf.ones(mesh, mtf.Shape(token_dims), tf.int32)

    # create mask
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', n_ctx)
    memory_length_dim = mtf.Dimension('memory_length', n_ctx + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', n_ctx)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features = {
        "attn_bias": biasmask_attn_weights(mesh, length_dim,
                                           memory_length_dim, variable_dtype),
        "embd_dim": embd_dim,
        "vocab_dim": vocab_dim,
        "embed_sequence_dim": embed_sequence_dim,
        "memory_length_dim": memory_length_dim,
    }

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features, other_features, params, mesh,
                                  variable_dtype=variable_dtype)

        # Lower the mtf graph onto a single placement device.
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)
def lpt_init_single(lr_field, a0, kvec_lr, halo_size, lr_shape, hr_shape,
                    part_shape, antialias=True, order=1, post_filtering=True,
                    cosmology=Planck15):
    """Build first-order LPT positions, momenta and forces from a field.

    Args:
        lr_field: mtf tensor holding the input field; its first dimension is
            used as the batch dimension.
        a0: scale factor at which the state is evaluated.
        kvec_lr: list of three mtf tensors with the Fourier wave vectors.
        halo_size: number of cells padded on each side of every block.
        lr_shape: list of dimensions of the full-resolution field.
        hr_shape: list of dimensions of the block-decomposed field.
        part_shape: list of dimensions passed to ``mesh_ops.mtf_indices``.
        antialias: unused in this function.
        order: unused in this function (only first order is computed).
        post_filtering: unused in this function.
        cosmology: cosmology handed to ``PerturbationGrowth``.

    Returns:
        Tuple ``(X, P, F)``: displaced positions, momenta and forces.
    """
    a = a0
    batch_dim = lr_field.shape[0]
    lnc = lr_shape[-1].size
    field_mesh = lr_field.mesh

    # Create particles on the high resolution grid
    mstate = mesh_ops.mtf_indices(field_mesh, shape=part_shape,
                                  dtype=tf.float32)
    X = mtf.einsum([mtf.ones(field_mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    k_sizes = [kcomp.shape[0] for kcomp in kvec_lr]
    k_dims_lr = [k_sizes[2], k_sizes[0], k_sizes[1]]
    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

    grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(
        lr_kfield, kvec_lr)

    # Reorder the low res FFTs which were transposed: y, z, x
    grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]

    def _add_block_axes(x):
        # Insert three singleton axes right after the leading axis.
        for _ in range(3):
            x = tf.expand_dims(x, axis=1)
        return x

    def _displacement_component(kcomponent):
        # Back to real space, reshape into blocks, exchange halos, read out.
        f = mesh_utils.c2r3d(kcomponent, lr_shape[-3:])
        block_dims = [
            mtf.Dimension('sx_block', lnc // hr_shape[1].size),
            mtf.Dimension('sy_block', lnc // hr_shape[2].size),
            mtf.Dimension('sz_block', lnc // hr_shape[3].size),
        ]
        f = mtf.slicewise(
            _add_block_axes,
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + block_dims),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size)
        return mesh_utils.cic_readout(f, X, halo_size)

    components = []
    for kcomponent in grad_kfield_lr:
        components.append(_displacement_component(kcomponent))

    # Readout to particle positions
    displacement = mtf.stack(list(components), "ndim", axis=4)

    pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
    DX = pt.D1(a) * displacement
    momentum_factor = a ** 2 * pt.f1(a) * pt.E(a)
    P = momentum_factor * DX
    force_factor = a ** 2 * pt.E(a) * pt.gf(a) / pt.D1(a)
    F = force_factor * DX
    # TODO: Implement 2nd order LPT

    # Moves the particles according to displacement
    X = X + DX

    return X, P, F
def hidden_to_logits(self, hidden: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
    """Function called by mtf transformer to get the logits.

    Per-component logits are mixed with prior proportions obtained from
    ``self._sigmoid_tree``, then passed through ``self._rearrange_sentinels``.

    Args:
        hidden: an mtf.Tensor, hidden model states of the final decoder layer.
        context: a transformer.Context, the context used for the call to the
            transformer.

    Returns:
        An mtf.Tensor, the logits.
    """
    out_dim = self._output_dim
    copy_dim = self._copy_output_dim
    hidden *= out_dim.size**-0.5

    # Per-component context vectors and their logits over the vocabulary.
    hidden_as_copy = mtf.rename_dimension(hidden, out_dim.name, copy_dim.name)
    component_contexts = mtf.einsum([hidden_as_copy, self._context_weights],
                                    reduced_dims=[copy_dim])
    component_contexts = mtf.tanh(component_contexts +
                                  self._context_weights_bias)
    component_logits = mtf.einsum(
        [component_contexts, self._embedding_weights],
        reduced_dims=[out_dim])
    component_logits = self._dropout(component_logits, context)

    # Prior logits: a frequent-vocabulary part plus a shared part.
    prior_projection = mtf.einsum([self._prior_weights, hidden],
                                  reduced_dims=[out_dim])
    prior_tanh = mtf.tanh(prior_projection + self._prior_weights_bias)
    prior_tanh = self._dropout(prior_tanh, context)

    prior_shared_logits = mtf.einsum([self._prior_gates_vector, hidden],
                                     reduced_dims=[out_dim])
    prior_frequent_vocab_logits = (
        mtf.einsum([self._prior_vocab_vector, prior_tanh]) +
        prior_shared_logits + self._prior_bias)

    # The rare vocabulary reuses the shared logits, broadcast over its dim.
    rare_ones = mtf.ones(self._mesh,
                         mtf.Shape([self._rare_vocab_dim]),
                         dtype=prior_shared_logits.dtype)
    prior_rare_vocab_logits = rare_ones * prior_shared_logits
    prior_logits = mtf.concat(
        [prior_frequent_vocab_logits, prior_rare_vocab_logits],
        self._vocab_dim.name)

    # Gaussian noise is added to the prior logits only while training.
    if context.train and self._noise_std_dev != 0.0:
        prior_logits += mtf.random_normal(self._mesh,
                                          prior_logits.shape,
                                          stddev=self._noise_std_dev)

    prior_proportions = self._sigmoid_tree(prior_logits)

    logits = mtf.einsum([component_logits, prior_proportions],
                        reduced_dims=[self._gates_dim])
    return self._rearrange_sentinels(logits)
def test_sampling():
    """Smoke test: run autoregressive sampling and lower the samples to TF.

    Uses (and mutates: ``params["mode"]`` is set to ``"predict"``) the
    module-level ``params`` dict.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # A single prompt token, right-padded with three extra positions.
    prompt_batch_dim = mtf.Dimension("batch", 1)
    prompt_seq_dim = mtf.Dimension("sequence", 1)
    prompt_shape = mtf.Shape((prompt_batch_dim, prompt_seq_dim))
    inputs = mtf.ones(mesh, prompt_shape, tf.int32)
    inputs = mtf.pad(inputs, [0, 3], prompt_seq_dim.name)

    # create mask
    n_ctx = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', n_ctx)
    memory_length_dim = mtf.Dimension('memory_length', n_ctx + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', n_ctx)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {
        "attn_bias": biasmask_attn_weights(mesh, length_dim,
                                           memory_length_dim,
                                           mtf.VariableDType(tf.float32)),
        "embd_dim": embd_dim,
        "vocab_dim": vocab_dim,
        "embed_sequence_dim": embed_sequence_dim,
        "memory_length_dim": memory_length_dim,
    }

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs,
            other_features=other_features,
            params=params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"],
            stop_at_token=params["eos_id"],
            sampling_use_entmax=True)

        # Lower the mtf graph onto a single placement device.
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(
            shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
def test_get_indices(self):
    """Check scores and indices returned by ProductKeyValueMemory.get_indices."""
    key_size = 2
    n_keys = 3
    product_size = 2
    head_size = 2
    batch = 2
    seq_len = 2
    knn = 2

    batch_dim = mtf.Dimension("batch", batch)
    seq_dim = mtf.Dimension("length", seq_len)
    head_dim = mtf.Dimension("n_heads", head_size)
    product_dim = mtf.Dimension("product_key", product_size)
    n_key_dim = mtf.Dimension("n_keys", n_keys)
    key_dim = mtf.Dimension("key", key_size // 2)
    knn_dim = mtf.Dimension("knn", knn)

    query_shape = mtf.Shape(
        [batch_dim, seq_dim, head_dim, product_dim, key_dim])
    keys_shape = mtf.Shape([head_dim, product_dim, n_key_dim, key_dim])
    query = mtf.ones(self.mesh, query_shape)

    # Laid out as [head][product half][key index][key component].
    keys_vals = [
        [
            [[4], [1], [2]],
            [[2], [-1], [2]],
        ],
        [
            [[1], [2], [5]],
            [[6], [1], [4]],
        ],
    ]
    # h1:
    #   First scores:
    #     [4, 2]
    #     [2, 2]
    #   Cartesian added scores:
    #     [6, 6]
    #   Indices:
    #     [0, 2]    [0*n_k + 0, 0*n_k + 2]
    # h2:
    #   First scores:
    #     [5, 2]
    #     [6, 4]
    #   Cartesian added scores:
    #     [11, 9]
    #   Indices:
    #     [6, 8]    [2*n_k+0, 2*n_k+2]
    result_dims = [batch, seq_len, head_size, knn]
    expected_scores = np.broadcast_to(np.array([[6, 6], [11, 9]]),
                                      result_dims)
    expected_indices = np.broadcast_to(np.array([[0, 2], [6, 8]]),
                                       result_dims)

    keys = mtf.constant(self.mesh, keys_vals, keys_shape)
    pkm = memory_layers.ProductKeyValueMemory(key_size, n_keys, head_size,
                                              knn)
    mtf_scores, mtf_indices = pkm.get_indices(keys, query)

    # Shapes.
    expected_shape = mtf.Shape([batch_dim, seq_dim, head_dim, knn_dim])
    self.assertEqual(expected_shape, mtf_scores.shape)
    self.assertEqual(expected_shape, mtf_indices.shape)

    # Values
    lowering_s, scores = self._export_to_tf_tensor(mtf_scores)
    lowering_i, indices = self._export_to_tf_tensor(mtf_indices)
    self.evaluate(tf.global_variables_initializer())
    for lowering in (lowering_s, lowering_i):
        self.evaluate(lowering.copy_masters_to_slices())
    scores, indices = self.evaluate([scores, indices])

    self.assertAllEqual(expected_scores, scores)
    self.assertAllEqual(expected_indices, indices)
def create_dummy_model(mesh, shapes, n_blocks=2, block_param_size_str="2_2",
                       block_repeat_size_str="1_1"):
    """Creates a dummy model and layer stack with 4-dimensional input.

    Args:
        mesh: the mtf mesh on which the context tensors are created.
        shapes: sequence of four sizes
            ``(outer_batch_size, batch_size, length, d_model)``.
        n_blocks: number of blocks passed to the funnel layer stack.
        block_param_size_str: underscore-separated integers, e.g. ``"2_2"``.
        block_repeat_size_str: underscore-separated integers, e.g. ``"1_1"``.

    Returns:
        Tuple ``(layer_stack, context)``.
    """
    assert len(shapes) == 4
    outer_batch_size, batch_size, length, d_model = shapes

    batch_dim = mtf.Dimension("batch", batch_size)
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    length_dim = mtf.Dimension("length", length)

    block_param_size = [int(tok) for tok in block_param_size_str.split("_")]
    block_repeat_size = [int(tok) for tok in block_repeat_size_str.split("_")]

    # One (SelfAttention, DenseReluDense) pair per parameterised sub-layer;
    # the same two layer instances are repeated in the list.
    attention_and_ffn = [
        transformer_layers.SelfAttention(),
        transformer_layers.DenseReluDense(),
    ]
    n_sublayers = np.prod(block_param_size)
    layers = attention_and_ffn * n_sublayers

    layer_stack = funnel_transformer.FunnelTransformerLayerStack(
        layers=layers,
        n_blocks=n_blocks,
        block_param_size=block_param_size,
        block_repeat_size=block_repeat_size,
        sublayers_initial=[
            transformer.sublayer_dropout,
        ],
        sublayers_per_layer=[
            transformer.sublayer_rms_norm,
            transformer.sublayer_call_layer,
            transformer.sublayer_dropout,
            transformer.sublayer_residual,
        ],
        sublayers_final=[
            transformer.sublayer_rms_norm,
            transformer.sublayer_dropout,
        ])

    model = transformer.Unitransformer(input_vocab_size=10,
                                       output_vocab_size=10,
                                       autoregressive=False,
                                       max_length=8,
                                       d_model=d_model,
                                       layer_stack=layer_stack)

    sequence_id = mtf.ones(mesh, mtf.Shape([length_dim]))
    position = mtf.range(mesh, length_dim, dtype=tf.int32)
    context = transformer.Context(
        model=model,
        mesh=mesh,
        batch_dims=[batch_dim, outer_batch_dim],
        length_dim=length_dim,
        variable_dtype=mtf.VariableDType(tf.float32),
        sequence_id=sequence_id,
        position=position)
    return layer_stack, context
def recon_model(mesh,
                data,
                bparams,
                ipkerror,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the mesh-tensorflow graph of a biased forward model and its loss.

    The graph holds one variable, ``'linear'``, which is evolved with
    ``mtfpm.lpt_init_single`` and ``mtfpm.nbody_single`` and painted onto a
    grid. Three fields derived from the variable (mean-subtracted field, its
    mean-subtracted square, and a mean-subtracted shear field) are read out,
    painted at the final positions and combined with the bias parameters.
    The loss is a data term on the noise-weighted, smoothed residual plus a
    power-spectrum prior on the variable.

    Besides its arguments the function reads ``FLAGS.nx``, ``FLAGS.ny`` and
    ``FLAGS.hsize``, and loads ``'..//data/Planck15_a1p00.txt'`` relative to
    the current working directory. Keyword defaults taken from ``FLAGS`` are
    evaluated once, when the function is defined.

    Args:
        mesh: mtf mesh on which all tensors are created.
        data: observed field; converted with ``tf.convert_to_tensor`` and
            imported with shape ``[batch, nx, ny, nz]``.
        bparams: sequence of three bias coefficients ``(b1, b2, bs2)``.
        ipkerror: pair of arrays; element 0 gives the reference ``k`` range
            and element 1 the error power used to weight the residual.
        R0: smoothing scale applied to the residual in Fourier space.
        x0: initial value of the ``'linear'`` variable, or ``None`` for a
            unit-variance random normal initialisation.
        nc: number of grid cells per side.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: initial scale factor.
        a: final scale factor.
        nsteps: number of scale-factor stages between ``a0`` and ``a``.
        dtype: ``tf.float32`` or ``tf.float64``.

    Returns:
        Tuple ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field, model]``, ``metrics`` is
        ``[chisq, prior, loss]`` and ``kv`` is the list ``[ky, kz, kx]`` of
        wave-vector tensors.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor ``tf.float64``.
    """
    b1, b2, bs2 = bparams
    kerror = ipkerror[0].astype(np.float32)
    perror = ipkerror[1].astype(np.float32)

    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(
            "recon_model supports only tf.float32 and tf.float64, got %r" %
            (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block edge.
    min_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block:
        new_size = int(0.5 * min_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Second column of the table is the power used for the prior; the file
    # is parsed once (it used to be parsed twice).
    plin = np.loadtxt('..//data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])
    pke_dim = mtf.Dimension("epk", len(perror))
    pkerror = mtf.import_tf_tensor(mesh,
                                   perror.astype(npdtype),
                                   shape=[pke_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]
    # Keyword arguments shared by the cic_readout_fr / cic_paint_fr calls.
    fr_kwargs = dict(hr_shape=hr_shape,
                     halo_size=halo_size,
                     splittables=splittables,
                     mesh=mesh)

    #
    # Begin simulation
    if x0 is None:
        initializer = tf.random_normal_initializer(mean=0.0,
                                                   stddev=1,
                                                   seed=None)
    else:
        initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=initializer)

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                     hr_shape, halo_size, splittables, mesh)

    ##
    # Get the fields for bias
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # Wave-vector dimensions in the order x, y, z (kv itself is y, z, x).
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    tfnc = cswisef.float_to_mtf(nc * 1., mesh, scalar)
    tfbs = cswisef.float_to_mtf(bs, mesh, scalar)

    # Mean-subtracted field.
    initc = fieldvar
    d0 = initc - mtf.reduce_mean(initc)
    # Mean-subtracted squared field.
    d2 = initc * initc
    d2 = d2 - mtf.reduce_mean(d2)
    # Mean-subtracted shear field.
    cfield = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
    shearfield = mtf.zeros(mesh, shape=part_shape)
    shearfield = shear(shearfield, cfield, kv, tfnc, tfbs)
    s2 = shearfield - mtf.reduce_mean(shearfield)

    # Read the three fields at the grid positions X ...
    dread = mcomp.cic_readout_fr(d0, [X], **fr_kwargs)
    d2read = mcomp.cic_readout_fr(d2, [X], **fr_kwargs)
    s2read = mcomp.cic_readout_fr(s2, [X], **fr_kwargs)

    # ... and paint them as weights at the final particle positions.
    ed = mtf.zeros(mesh, shape=part_shape)
    ed2 = mtf.zeros(mesh, shape=part_shape)
    es2 = mtf.zeros(mesh, shape=part_shape)
    ed = mcomp.cic_paint_fr(ed,
                            final_state,
                            output_shape=part_shape,
                            weights=dread,
                            **fr_kwargs)
    ed2 = mcomp.cic_paint_fr(ed2,
                             final_state,
                             output_shape=part_shape,
                             weights=d2read,
                             **fr_kwargs)
    es2 = mcomp.cic_paint_fr(es2,
                             final_state,
                             output_shape=part_shape,
                             weights=s2read,
                             **fr_kwargs)

    model = ed * b1 + ed2 * b2 + es2 * bs2
    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)
    diff = model - mtfdata

    # Get prior
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |kfield| divided by the square root of the interpolated power.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #* nc**3

    # Total loss
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

    def _cwise_diff(kfield, pk, kx, ky, kz):
        # kfield divided by the square root of the interpolated error power.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(x=kk,
                                                 x_ref_min=kerror.min(),
                                                 x_ref_max=kerror.max(),
                                                 y_ref=pk)
        priormesh = tf.reshape(pkmesh, kshape)
        priormesh = tf.cast(priormesh**0.5, kfield.dtype)
        return kfield / priormesh

    cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv, output_dtype=cdtype)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian weighting exp(-k^2 (R0 * bs / nc)^2) in Fourier space.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the mesh-tensorflow graph of a neural-network forward model.

    The graph holds one variable, ``'linear'``, which is evolved with
    ``mtfpm.lpt_init_single`` and ``mtfpm.nbody_single`` and painted onto a
    grid. Two small networks whose weights come from ``setupfnn()`` are
    applied to the painted field and two smoothed versions of it, and their
    outputs are multiplied. The product is read out on the grid and painted
    again at positions shifted along the last coordinate. The loss is a data
    term on the smoothed log-residual plus a power-spectrum prior on the
    variable.

    Besides its arguments the function reads ``FLAGS.nx``, ``FLAGS.ny`` and
    ``FLAGS.hsize``, and loads ``'../flowpm/data/Planck15_a1p00.txt'``
    relative to the current working directory. Keyword defaults taken from
    ``FLAGS`` are evaluated once, when the function is defined.

    Args:
        mesh: mtf mesh on which all tensors are created.
        datasm: observed field; converted with ``tf.convert_to_tensor`` and
            imported with shape ``[batch, nx, ny, nz]``.
        rsdfactor: factor multiplied into the read-out velocity ``vz``.
        M0: constant added to model and data before taking the logarithm.
        R0: smoothing scale applied to the residual in Fourier space.
        width: factor multiplying the first network's output before the
            sigmoid.
        off: optional offset added to the log-residual, or ``None``.
        istd: optional factor multiplied into the residual, or ``None``
            (the residual is then divided by 0.25).
        x0: initial value of the ``'linear'`` variable, or ``None`` for a
            unit-variance random normal initialisation.
        nc: number of grid cells per side.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: initial scale factor.
        a: final scale factor.
        nsteps: number of scale-factor stages between ``a0`` and ``a``.
        dtype: ``tf.float32`` or ``tf.float64``.

    Returns:
        Tuple ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field, model]``, ``metrics`` is
        ``[chisq, prior, loss]`` and ``kv`` is the list ``[ky, kz, kx]`` of
        wave-vector tensors.

    Raises:
        ValueError: if ``dtype`` is neither ``tf.float32`` nor ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(
            "recon_model supports only tf.float32 and tf.float64, got %r" %
            (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block edge.
    min_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block:
        new_size = int(0.5 * min_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Second column of the table is the power used for the prior; the file
    # is parsed once (it used to be parsed twice).
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]
    # Keyword arguments shared by the cic_readout_fr / cic_paint_fr calls.
    fr_kwargs = dict(hr_shape=hr_shape,
                     halo_size=halo_size,
                     splittables=splittables,
                     mesh=mesh)

    #
    # Begin simulation
    if x0 is None:
        initializer = tf.random_normal_initializer(mean=0.0,
                                                   stddev=1,
                                                   seed=None)
    else:
        initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=initializer)

    ##
    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     **fr_kwargs)

    ##
    x = final_field

    # Network weights and normalisation constants.
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]

    # Wave-vector dimensions in the order x, y, z (kv itself is y, z, x).
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc = float_to_mtf(nc * 1., mesh, scalar)
    tfbs = float_to_mtf(bs, mesh, scalar)

    # Apply cwise_decic in Fourier space, then subtract 1.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two smoothed copies (scales R1 and R2) and their difference.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        # Convolve each input with `kernel`, normalise, run a 3-layer ReLU
        # network and squash with a sigmoid scaled by `width`.
        # NOTE: an alternative using tfwrap3D(...) with 'VALID' padding was
        # left commented out in the original.
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts, [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        # Stack the inputs as channels, normalise, run a 3-layer ELU network
        # and rescale the output with msy / mmy.
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts, [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    ##RSD below
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # Smoothed mass field, offset by 1e-5 before it is used as a divisor.
    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    masssm = masssm + 1e-5
    # NOTE(review): imasssm is never used below -- confirm it can be dropped.
    imasssm = mtf.pow(x, -1.)

    # Last component of the second state entry, used as paint weights.
    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    # NOTE(review): the zeros tensor is overwritten immediately and the paint
    # below starts from final_field rather than from zeros -- confirm intent.
    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(final_field,
                              final_state,
                              output_shape=part_shape,
                              weights=vzweights,
                              **fr_kwargs)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    #Shift
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X], **fr_kwargs)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    # NOTE(review): the shift uses vzweights, not the vz computed just above
    # (vz is only printed) -- confirm which one is intended.
    Xrsd = mtf.slicewise(lambda x, vz: x + tf.stack(
        [tf.zeros_like(vz), tf.zeros_like(vz), vz], 4), [X, vzweights],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)
    modelread = mcomp.cic_readout_fr(model, [X], **fr_kwargs)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd],
                                  output_shape=part_shape,
                                  weights=modelread,
                                  **fr_kwargs)

    model = modelrsd
    print(modelrsd)

    #Likelihood and prior here
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |kfield| divided by the square root of the interpolated power.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)

    mtfoff = None
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        if mtfoff is not None:
            # Kept as in the original, which adds the offset a second time
            # here ("For some reason, doing things wrong this one").
            diff = (diff + mtfoff) * mtfistd
        else:
            # Without an offset the original raised NameError on mtfoff.
            diff = diff * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian weighting exp(-k^2 (R0 * bs / nc)^2) in Fourier space.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv