def lpt_init_single(lr_field,
                    a0,
                    kvec_lr,
                    halo_size,
                    lr_shape,
                    hr_shape,
                    part_shape,
                    antialias=True,
                    order=1,
                    post_filtering=True,
                    cosmology=Planck15):
  """Build an initial (X, P, F) particle state from a low-resolution field.

  Particles are created from mesh indices of shape `part_shape`, the
  gradient/Laplace kernel is applied to the Fourier transform of `lr_field`,
  and each resulting component is read out at the particle positions. The
  stacked read-outs are scaled by growth factors from `PerturbationGrowth`
  evaluated at `a0`.

  Only the first-order displacement is applied (second order is still a
  TODO). `antialias`, `order` and `post_filtering` are accepted but are not
  used by this function body.

  Parameters:
  -----------
  lr_field: mtf tensor
    Low-resolution field; its first dimension is used as the batch dimension.

  a0: float
    Scale factor at which the growth functions are evaluated.

  kvec_lr: sequence
    Wave-vector components for the low-resolution mesh (three entries are
    read).

  halo_size: int
    Padding added on both sides of every block dimension before read-out.

  lr_shape, hr_shape, part_shape: sequences of mtf.Dimension
    Shapes of the low-resolution field, the blocked high-resolution field
    and the particle grid.

  cosmology: astropy.cosmology
    Cosmology object handed to `PerturbationGrowth`.

  Returns:
  --------
  (X, P, F): particle positions displaced by DX, and DX rescaled by
  a**2 * f1 * E and by a**2 * E * gf / D1 respectively (presumably momenta
  and forces -- confirm against callers).
  """

  def _insert_block_axes(x):
    # Insert three singleton axes, each one at position 1.
    for _ in range(3):
      x = tf.expand_dims(x, axis=1)
    return x

  scale = a0
  batch_dim = lr_field.shape[0]
  lnc = lr_shape[-1].size

  # Create particles on the high resolution grid, replicated over the batch.
  grid_indices = mesh_ops.mtf_indices(
      lr_field.mesh, shape=part_shape, dtype=tf.float32)
  X = mtf.einsum(
      [mtf.ones(lr_field.mesh, [batch_dim]), grid_indices],
      output_shape=[batch_dim] + grid_indices.shape[:])

  # The FFT output is transposed, so the k dimensions are permuted to match.
  raw_k_dims = [d.shape[0] for d in kvec_lr]
  k_dims_lr = [raw_k_dims[i] for i in (2, 0, 1)]
  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  gradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)
  # Reorder the low-resolution components, which were transposed (y, z, x).
  gradients = [gradients[i] for i in (2, 0, 1)]

  # Target layout of the custom reshape: hr blocks plus per-block sizes.
  block_dims = [
      mtf.Dimension(label, lnc // blocks.size)
      for label, blocks in zip(('sx_block', 'sy_block', 'sz_block'),
                               hr_shape[1:4])
  ]
  blocked_shape = mtf.Shape(hr_shape[0:4] + block_dims)
  splittable = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

  readouts = []
  for component in gradients:
    real = mesh_utils.c2r3d(component, lr_shape[-3:])
    real = mtf.slicewise(
        _insert_block_axes, [real],
        output_dtype=tf.float32,
        output_shape=blocked_shape,
        name='my_reshape',
        splittable_dims=splittable)
    # Pad every block dimension with a halo, then exchange the halos.
    for size_dim in hr_shape[-3:]:
      real = mtf.pad(real, [halo_size, halo_size], size_dim.name)
    padded_dims = real.shape[-3:]
    for blocks_dim, size_dim in zip(hr_shape[1:4], padded_dims):
      real = mesh_ops.halo_reduce(real, blocks_dim, size_dim, halo_size)
    readouts.append(mesh_utils.cic_readout(real, X, halo_size))

  # Readout to particle positions
  displacement = mtf.stack(list(readouts), "ndim", axis=4)

  growth = PerturbationGrowth(cosmology, a=[scale], a_normalize=1.0)
  DX = growth.D1(scale) * displacement
  momentum_scale = scale ** 2 * growth.f1(scale) * growth.E(scale)
  P = momentum_scale * DX
  force_scale = (
      scale ** 2 * growth.E(scale) * growth.gf(scale) / growth.D1(scale))
  F = force_scale * DX
  # TODO: Implement 2nd order LPT

  # Move the particles according to the displacement.
  X = X + DX
  return X, P, F
def _get_decoder_inputs(self, context):
  """Computes the per-module decoder inputs for transparent attention.

  The result is cached on `context` so that the mixing variable is not
  created again when the layer's call function runs under a different tf
  variable scope.

  Args:
    context: a Context

  Returns:
    a list of `self.decoder_num_modules` tensors, one per decoder module.
    Each is a softmax-weighted combination of the selected encoder layer
    outputs, reshaped to drop the decoder-module dimension.
  """
  # Reuse the cached value if a previous call already populated it.
  try:
    return context.decoder_layers_per_module
  except AttributeError:
    pass

  renamed_outputs = [
      mtf.layers.rename_length_to_memory_length(layer_output)
      for layer_output in context.encoder_layer_outputs
  ]
  stride = self.layers_per_encoder_module
  enc_dim = mtf.Dimension(
      "encoder_module_outputs", size=self.encoder_num_modules + 1)
  dec_dim = mtf.Dimension(
      "decoder_module_inputs", size=self.decoder_num_modules)

  # The first encoder output plus every `stride`-th output after it.
  selected_outputs = [renamed_outputs[0]] + renamed_outputs[stride::stride]
  stacked_outputs = mtf.stack(
      selected_outputs, dim_name="encoder_module_outputs")

  init_stddev = (enc_dim.size * dec_dim.size)**-0.5
  mix_logits = mtf.get_variable(
      context.mesh,
      "w",
      mtf.Shape([enc_dim, dec_dim]),
      initializer=tf.random_normal_initializer(stddev=init_stddev),
      dtype=context.variable_dtype)
  if context.train and self.dropout_rate != 0.0:
    mix_logits = mtf.dropout(mix_logits, 1.0 - self.dropout_rate)

  # Normalise over encoder modules, then mix them for each decoder module.
  mix_weights = mtf.softmax(mix_logits, reduced_dim=enc_dim)
  mixed = mtf.einsum([mix_weights, stacked_outputs], reduced_dims=[enc_dim])
  pieces = mtf.split(
      mixed, split_dim=dec_dim, num_or_size_splits=dec_dim.size)

  per_module = []
  for piece in pieces:
    # Drop the leading dimension of the mixed tensor from each split piece.
    per_module.append(mtf.reshape(piece, mixed.shape.dims[1:]))
  context.decoder_layers_per_module = per_module
  return context.decoder_layers_per_module
def compute_loss(self, decoder: transformer.Unitransformer,
                 hidden: mtf.Tensor, targets: mtf.Tensor,
                 context: transformer.Context) -> mtf.Tensor:
  """Returns the loss without computing a softmax over the entire vocab.

  Each tail cluster contributes its own loss, computed on the targets and
  hidden states selected by that cluster's mask. The head cluster then
  contributes a loss in which every target covered by a tail cluster is
  replaced by `self._head_cluster.end_token_id + i`, where `i` is the
  index of that tail cluster.

  Args:
    decoder: the Unitransformer passed through to each cluster's loss.
    hidden: hidden states fed to the clusters.
    targets: target ids.
    context: the Context passed through to each cluster's loss.

  Returns:
    the sum of all tail-cluster losses and the head-cluster loss.
  """
  loss = 0
  masks = []
  for tail in self._tail_clusters:
    mask = tail.get_cluster_mask(targets)
    masks.append(mask)

    if cluster_needs_projection := (tail.length_projection_factor != 1):
      # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
      # to reduce the risk of overflowing the projection.
      projection = tail.get_project_to_cluster_length(
          mask, dtype=targets.dtype)
      cluster_targets = mtf.einsum(
          [projection, targets],
          reduced_dims=[targets.shape.get_dim_by_name("length")])
      cluster_hidden = mtf.einsum(
          [mtf.cast(projection, hidden.dtype), hidden],
          reduced_dims=[hidden.shape.get_dim_by_name("length")])
    else:
      # No length projection: zero out positions outside the cluster.
      cluster_targets = mtf.where(mask, targets, 0)
      cluster_hidden = mtf.where(mask, hidden, 0)
    del cluster_needs_projection

    loss += tail.compute_loss(decoder, cluster_hidden, cluster_targets,
                              context)

  tail_clusters_dim = mtf.Dimension("tail_clusters", len(masks))
  # Head-vocabulary id of each tail cluster, placed where its mask is set.
  node_ids_per_cluster = [
      (self._head_cluster.end_token_id + index) *
      mtf.cast(mask, targets.dtype) for index, mask in enumerate(masks)
  ]
  stacked_node_ids = mtf.stack(node_ids_per_cluster, tail_clusters_dim.name)
  tail_node_targets = mtf.reduce_sum(
      stacked_node_ids, reduced_dim=tail_clusters_dim)

  # Where a tail node id is present (non-zero) it replaces the raw target.
  in_tail = mtf.cast(tail_node_targets, tf.bool)
  head_targets = mtf.where(in_tail, tail_node_targets, targets)
  loss += self._head_cluster.compute_loss(decoder, hidden, head_targets,
                                          context)
  return loss
def force_single(state,
                 lr_shape,
                 hr_shape,
                 kvec_lr,
                 halo_size,
                 cosmology=Planck15,
                 pm_nc_factor=1,
                 **kwargs):
  """
  Estimate force on the particles given a state, on a single mesh.

  The particles are painted on the blocked high-resolution mesh, the halos
  are exchanged and trimmed, the gradient/Laplace kernel is applied in
  Fourier space, and each component is read back at the particle positions.

  Parameters:
  -----------
  state: tensor
    Input state, unpacked as (X, P, F). Presumably of shape
    (3, batch_size, npart, 3) -- TODO confirm.

  lr_shape, hr_shape: sequences of mtf.Dimension
    Shapes of the unblocked field and of the blocked high-resolution field.

  kvec_lr: sequence
    Wave-vector components of the mesh (three entries are read).

  halo_size: int
    Padding added on both sides of every block dimension.

  cosmology: astropy.cosmology
    Cosmology object; only `Om0` is read here.

  pm_nc_factor: int
    Must be 1; other factors are not supported yet.

  Returns:
  --------
  (X, P, F) where X and P are passed through unchanged and F is the stacked
  read-out multiplied by 1.5 and by `cosmology.Om0`.
  """

  def _drop_block_axes(x):
    return x[:, 0, 0, 0]

  def _insert_block_axes(x):
    # Insert three singleton axes, each one at position 1.
    for _ in range(3):
      x = tf.expand_dims(x, axis=1)
    return x

  X, P, F = state
  #TODO: support different factor
  assert pm_nc_factor == 1
  lnc = lr_shape[-1].size
  part_shape = X.shape

  # Paint the particles on the (halo-padded) high resolution mesh.
  density = mtf.zeros(X.mesh, shape=hr_shape)
  for size_dim in hr_shape[-3:]:
    density = mtf.pad(density, [halo_size, halo_size], size_dim.name)
  density = mesh_utils.cic_paint(density, X, halo_size)
  padded_dims = density.shape[-3:]
  for blocks_dim, size_dim in zip(hr_shape[1:4], padded_dims):
    density = mesh_ops.halo_reduce(density, blocks_dim, size_dim, halo_size)
  # Remove borders
  for size_dim in hr_shape[-3:]:
    density = mtf.slice(density, halo_size, size_dim.size, size_dim.name)

  # Hack: custom reshape done through slicewise.
  lr_field = mtf.slicewise(
      _drop_block_axes, [density],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  # The FFT output is transposed, so the k dimensions are permuted to match.
  raw_k_dims = [d.shape[0] for d in kvec_lr]
  k_dims_lr = [raw_k_dims[i] for i in (2, 0, 1)]
  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  gradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)
  # Reorder the low-resolution components, which were transposed (y, z, x).
  gradients = [gradients[i] for i in (2, 0, 1)]

  # Target layout of the custom reshape: hr blocks plus per-block sizes.
  block_dims = [
      mtf.Dimension(label, lnc // blocks.size)
      for label, blocks in zip(('sx_block', 'sy_block', 'sz_block'),
                               hr_shape[1:4])
  ]
  blocked_shape = mtf.Shape(hr_shape[0:4] + block_dims)
  splittable = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

  readouts = []
  for component in gradients:
    real = mesh_utils.c2r3d(component, lr_shape[-3:])
    real = mtf.slicewise(
        _insert_block_axes, [real],
        output_dtype=tf.float32,
        output_shape=blocked_shape,
        name='my_reshape',
        splittable_dims=splittable)
    for size_dim in hr_shape[-3:]:
      real = mtf.pad(real, [halo_size, halo_size], size_dim.name)
    component_dims = real.shape[-3:]
    for blocks_dim, size_dim in zip(hr_shape[1:4], component_dims):
      real = mesh_ops.halo_reduce(real, blocks_dim, size_dim, halo_size)
    readouts.append(mesh_utils.cic_readout(real, X, halo_size))

  # Readout the force to particle positions
  F = mtf.stack(list(readouts), "ndim", axis=4)
  F = F * 1.5 * cosmology.Om0
  return X, P, F
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
  """
  Estimate force on the particles given a state, using a two-level mesh.

  The particles are painted on the blocked high-resolution mesh. A
  downsampled copy of that field is processed globally at low resolution,
  while the high-resolution field is processed per block with its large
  scales subtracted. The two contributions are summed and read out at the
  particle positions.

  Parameters:
  -----------
  state: tensor
    Input state, unpacked as (X, P, F). Presumably of shape
    (3, batch_size, npart, 3) -- TODO confirm.

  lr_shape, hr_shape: sequences of mtf.Dimension
    Shapes of the low-resolution field and of the blocked high-resolution
    field.

  kvec_lr, kvec_hr: sequences
    Wave-vector components of the low- and high-resolution meshes (three
    entries of each are read).

  halo_size: int
    Padding added on both sides of every high-resolution block dimension.
    The low-resolution halo is `halo_size // 2**downsampling_factor`.

  cosmology: astropy.cosmology
    Cosmology object; only `Om0` is read here.

  downsampling_factor: int
    Passed to `mesh_utils.downsample` / `mesh_utils.upsample`; sizes are
    divided by `2**downsampling_factor` when slicing the low-res field.

  pm_nc_factor: int
    Must be 1; other factors are not supported yet.

  antialias: bool
    Forwarded to the downsampling used when removing the large scales from
    the high-resolution component.

  Returns:
  --------
  (X, P, F) where X and P are passed through unchanged and F is the stacked
  read-out multiplied by 1.5 and by `cosmology.Om0`.
  """
  X, P, _ = state
  # TODO: support different factor
  assert pm_nc_factor == 1
  lnc = lr_shape[-1].size
  part_shape = X.shape
  # Halo width on the downsampled grid.
  lr_halo_size = halo_size // 2**downsampling_factor

  k_dims_lr = [d.shape[0] for d in kvec_lr]
  k_dims_hr = [d.shape[0] for d in kvec_hr]
  # Reorder the FFT dimensions, which were transposed (y, z, x).
  k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
  k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

  # Paint the particles on the (halo-padded) high resolution mesh.
  field = mtf.zeros(X.mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
  field = mesh_utils.cic_paint(field, X, halo_size)
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

  # Split the field into low and high resolution. A trailing singleton
  # 'h_dim' is added for downsample() and removed again afterwards.
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  # NOTE(review): antialias is hard-coded to True here, while the `antialias`
  # argument is honoured further down -- confirm whether this is intended.
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
  low = mtf.reshape(low, low.shape[:-1])
  hr_field = mtf.reshape(field, field.shape[:-1])

  # Trim the (downsampled) halo from the low-resolution field.
  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, lr_halo_size,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack: custom reshape done through slicewise.
  lr_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

  # Long-range filter followed by the gradient/Laplace kernel. The gradient
  # kernel must consume the *filtered* field on both branches; previously
  # the low-res branch fed the unfiltered `lr_kfield` to the gradient
  # kernel, discarding the long-range result.
  kfield_lr = mesh_kernels.apply_longrange_kernel(
      lr_kfield, kvec_lr, r_split=0)
  kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(kfield_lr, kvec_lr)
  kfield_hr = mesh_kernels.apply_longrange_kernel(
      hr_kfield, kvec_hr, r_split=0)
  kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

  # Reorder the components, which were transposed (y, z, x).
  kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
  kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

  displacement = []
  for f, g in zip(kfield_lr, kfield_hr):
    # Low-resolution component: back to real space, re-blocked, halo-padded
    # and upsampled to the high-resolution block size.
    f = mesh_utils.c2r3d(f, lr_shape[-3:])
    f = mtf.slicewise(
        lambda x: tf.expand_dims(
            tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
        [f],
        output_dtype=tf.float32,
        output_shape=mtf.Shape(hr_shape[0:4] + [
            mtf.Dimension('sx_block', lnc // hr_shape[1].size),
            mtf.Dimension('sy_block', lnc // hr_shape[2].size),
            mtf.Dimension('sz_block', lnc // hr_shape[3].size)
        ]),
        name='my_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
    for block_size_dim in hr_shape[-3:]:
      f = mtf.pad(f, [lr_halo_size, lr_halo_size], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
      f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, lr_halo_size)
    f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
    f = mesh_utils.upsample(f, downsampling_factor)
    f = mtf.reshape(f, f.shape[:-1])

    # High-resolution component, shaped like the upsampled low-res one.
    g = mesh_utils.c2r3d(g, f.shape[-3:])
    high_shape = g.shape
    # And now we remove the large scales: subtract the down-then-upsampled
    # copy of g from g itself.
    g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
    _low = mesh_utils.downsample(g, downsampling_factor, antialias=antialias)
    g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                        g.shape)
    g = mtf.reshape(g, high_shape)

    d = mesh_utils.cic_readout(f + g, X, halo_size)
    displacement.append(d)

  # Readout the force to particle positions
  F = mtf.stack(displacement, "ndim", axis=4)
  F = F * 1.5 * cosmology.Om0
  return X, P, F