def update_context(self, context, x, pool_dim_name):
  """Update the length dimension, sequence_id and position information."""
  pooled_seq_length = x.shape.get_dim_by_name(pool_dim_name).size
  # For position, we slice the first `pooled_seq_length` indices instead of
  # striding. This ensures that e.g. the 3rd position before pooling becomes
  # the 2nd position after pooling, instead of keeping its pre-pooling index.
  new_context_position = mtf.slice(
      context.position,
      begin=0,
      size=pooled_seq_length,
      slice_dim_name=pool_dim_name)
  context.position = new_context_position
  new_length_dim = mtf.Dimension(
      name=pool_dim_name, size=pooled_seq_length)
  new_sequence_id = mtf.stride_tensor_1d(
      context.sequence_id,
      pool_dim=context.length_dim,
      pool_size=self.pooling_size)
  context.length_dim = new_length_dim
  context.sequence_id = new_sequence_id
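# Illustrative sketch (not part of the layer above): with pooling_size=2 and an
# original length of 8, slicing `context.position` keeps [0, 1, 2, 3] for the
# pooled sequence, whereas striding would have kept [0, 2, 4, 6]; the
# sequence_id, by contrast, is strided, so each pooled slot keeps the id of the
# segment it came from.
#
#   positions before pooling   : 0 1 2 3 4 5 6 7
#   positions after slicing    : 0 1 2 3
#   sequence_id before pooling : 1 1 1 1 2 2 2 2
#   sequence_id after striding : 1 1 2 2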
def halo_reduce(x, blocks_dim, block_size_dim, halo_size, wrap=True):
  """Reduce each block with the margins of adjacent blocks.

  Gets the left and right neighbours along blocks_dim and sums the overlapping
  margins along block_size_dim. Only supports halo sizes smaller than
  block_size / 2.

  Args:
    x: a Tensor.
    blocks_dim: a Dimension in x.shape
    block_size_dim: a Dimension in x.shape
    halo_size: an integer
    wrap: a boolean

  Returns:
    a Tensor with the same shape as x, in which the margins overlapping with
    adjacent blocks have been summed.
  """
  if halo_size == 0:
    return x

  block_size = block_size_dim.size
  assert halo_size <= block_size // 2

  left_margin = mtf.slice(x, 0, 2 * halo_size, block_size_dim.name)
  right_margin = mtf.slice(x, block_size_dim.size - 2 * halo_size,
                           2 * halo_size, block_size_dim.name)
  center = mtf.slice(x, 2 * halo_size, block_size_dim.size - 4 * halo_size,
                     block_size_dim.name)

  # Perform halo exchange: sum margins with the neighbouring blocks.
  left = mtf.shift(right_margin, 1, blocks_dim, wrap) + left_margin
  right = mtf.shift(left_margin, -1, blocks_dim, wrap) + right_margin

  # Recompose the block.
  left = mtf.pad(left, [0, block_size_dim.size - 2 * halo_size],
                 block_size_dim.name)
  right = mtf.pad(right, [block_size_dim.size - 2 * halo_size, 0],
                  block_size_dim.name)
  center = mtf.pad(center, [2 * halo_size, 2 * halo_size],
                   block_size_dim.name)
  x = left + center + right
  return x
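# A minimal NumPy sketch of what halo_reduce computes on a [blocks, block_size]
# array, assuming mtf.shift(t, +1, blocks_dim, wrap) moves block i into slot
# i + 1 (zero-filled at the edge when wrap=False). Illustrative only; not used
# by the mesh code above.
import numpy as np


def halo_reduce_np(x, h, wrap=True):
  """Toy analogue of halo_reduce for x of shape [num_blocks, block_size]."""
  if h == 0:
    return x
  assert 2 * h < x.shape[1]
  left_margin = x[:, :2 * h]
  right_margin = x[:, -2 * h:]
  center = x[:, 2 * h:-2 * h]
  left = np.roll(right_margin, 1, axis=0) + left_margin
  right = np.roll(left_margin, -1, axis=0) + right_margin
  if not wrap:
    left[0] = left_margin[0]      # block 0 has no left neighbour
    right[-1] = right_margin[-1]  # last block has no right neighbour
  return np.concatenate([left, center, right], axis=1)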
def call(self, context, x, losses=None):
  """Call the layer."""
  if self.canine_mode:
    # This is the canine-like ByT5 + LASC baseline in paper.
    return self.call_canine_encoder(context, x, losses=losses)

  if self.conv_type:
    if self.conv_type == "conv1d":
      tf.logging.info("Using 1d conv")
      tmp_output = mtf.Dimension("tmp_dim", x.shape[-1].size)
      orig_dim = x.shape[-1]
      x = mtf.layers.conv1d(x, tmp_output, filter_size=self.filter_size,
                            stride=1)
      x = mtf.rename_dimension(x, "tmp_dim", orig_dim.name)
      tf.logging.info(x)

  if self.norm:
    x = sublayer_rms_norm(x, None, context)

  o = x
  olength = o.shape.get_dim_by_name("length")
  o = custom_attention.gradient_based_subword_tokenization(
      o,
      olength,
      downsample=self.downsample_query,
      use_offsets=self.use_offsets,
      consider_chars_as_blocks=self.consider_chars_as_blocks,
      use_block_pos_embedding=self.use_block_pos_embedding,
      memory_embeddings=self.num_memory_slots,
      context=context,
      block_mixing_mode=self.block_mixing_mode,
      activation=self.rank_activation,
      downsample_function=self.gbst_pool)

  new_length_dim = o.shape.get_dim_by_name("length")
  context.length_dim = new_length_dim
  new_context_position = context.get_position()
  context.position = new_context_position
  context.sequence_id = mtf.slice(
      context.sequence_id,
      begin=0,
      size=new_length_dim.size,
      slice_dim_name=new_length_dim.name)

  if self.use_ffn:
    # not actually used in Charformer.
    tf.logging.info("Using FFN")
    o2 = self.ffn.call(context, o)
    o = o + o2
  if self.norm:
    o = sublayer_rms_norm(o, None, context)
  olength = o.shape.get_dim_by_name("length")
  return o, context
def call_canine_encoder(self, context, x, losses=None):
  """Call Canine baseline encoder (Byte level T5 + LASC in paper)."""
  # local attention
  params = self.make_params(context)
  q = params.compute_q(x)
  if self.shared_kv:
    kv = params.compute_kv(x)
    k = kv
    v = kv
  else:
    k = params.compute_k(x)
    v = params.compute_v(x)

  output_shape = x.shape
  x = custom_attention.local_attention_1d(
      q,
      k,
      v,
      length_dim=context.length_dim,
      length_dim_num_splits=1,
      key_dim=self.kv_dim,
      value_dim=self.kv_dim,
      fully_autoregressive=False,
      radius=self.radius,
      sequence_id=context.sequence_id,
      write_priority=context.write_priority,
      read_priority=context.read_priority,
      context=context,
      attention_kwargs=self.attention_kwargs_from_context(context))
  o = params.compute_output(x, output_shape=output_shape)

  # strided convolutions
  tmp_output = mtf.Dimension("tmp_dim", o.shape[-1].size)
  # downsample query args is reused here for "r"
  o = mtf.layers.conv1d(o, tmp_output, filter_size=self.filter_size,
                        stride=int(self.downsample_query))
  o = mtf.rename_dimension(o, "tmp_dim", "d_model")
  tf.logging.info(o)

  new_length_dim = o.shape.get_dim_by_name("length")
  context.length_dim = new_length_dim
  new_context_position = context.get_position()
  context.position = new_context_position
  context.sequence_id = mtf.slice(
      context.sequence_id,
      begin=0,
      size=new_length_dim.size,
      slice_dim_name=new_length_dim.name)
  return o, context
def downsample_hr_to_lr(field, lr_shape, hr_shape, downsampling_factor,
                        halo_size, splittables, mesh):
  # Reshaping array into high resolution mesh
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
  low = mtf.reshape(low, low.shape[:-1])

  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack using custom reshape because mesh is pretty dumb
  low = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=field.dtype,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=splittables)
  return low
def cic_paint_fr(field, state, output_shape, hr_shape, halo_size, splittables,
                 mesh, weights=None):
  '''Paint the positions from state onto a field given as a batch + 3D tensor.

  Ops performed:
    - reshape to hr_shape
    - pad
    - paint
    - halo_reduce
    - slice to remove pad
    - reshape to output_shape
  '''
  lnc = field.shape[-1].size
  field = mtf.slicewise(
      lambda x: tf.expand_dims(
          tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
      [field],
      output_dtype=tf.float32,
      output_shape=mtf.Shape(hr_shape[0:4] + [
          mtf.Dimension('sx_block', lnc // hr_shape[1].size),
          mtf.Dimension('sy_block', lnc // hr_shape[2].size),
          mtf.Dimension('sz_block', lnc // hr_shape[3].size)
      ]),
      name='my_reshape',
      splittable_dims=splittables)

  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
  field = mesh_utils.cic_paint(field, state[0], halo_size, weights)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    field = mtf.slice(field, halo_size, block_size_dim.size,
                      block_size_dim.name)

  field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [field],
      output_dtype=field.dtype,
      output_shape=output_shape,
      name='my_dumb_reshape',
      splittable_dims=splittables)
  return field
def synthetic_attention(q, k, v, memory_length_dim, key_dim, value_dim, bias=None, dropout_rate=0.0, dropout_broadcast_dims=None, extra_logit=None, synthesize=True, synthesize_mode="random_plus_alpha", factorized_dim=16, max_length=512, context=None): """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743). key_dim is a Dimension representing the channels in the queries and keys value_dim is a Dimension representing the channels in values memory_length_dim is a Dimension representing the different key/value pairs. Dimensions of q: other_query_dims + {key_dim} Dimensions of k: other_memory_dims + {memory_length_dim, key_dim} Dimensions of v: other_memory_dims + {memory_length_dim, value_dim} other_memory_dims is a subset of other_query_dims Typically, other_query_dims={batch, heads, length} Typically, other_memory_dims={batch, heads} Args: q: a Tensor k: a Tensor v: a Tensor memory_length_dim: a Dimension key_dim: a Dimension value_dim: a Dimension bias: a Tensor to be added into the attention logits. dropout_rate: a float. dropout_broadcast_dims: an optional list of mtf.Dimension extra_logit: an optional scalar or tensor synthesize: flag to use synthetic attention or not synthesize_mode: which variant of synthesizer to use factorized_dim: factorized dim for synthesizers max_length: max length of input sequence context: context since we need context mode Returns: Tensor with shape q.shape - key_dim + value_dim """ if synthesize: num_heads = v.shape.get_dim_by_name("heads") tf.logging.info("Using synthesizer") if synthesize_mode == "random": tf.logging.info("Using Random Synthesizers") r_shape = mtf.Shape([mtf.Dimension("length", max_length), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", max_length)]) r = mtf.get_variable(context.mesh, "R", r_shape, initializer=None, dtype=context.variable_dtype) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") logits = r r_shape = logits.shape elif synthesize_mode == "factorized": tf.logging.info("Using Factorized Random Synthesizers") k = factorized_dim r1_shape = mtf.Shape([mtf.Dimension("tmp", k), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r2_shape = mtf.Shape([mtf.Dimension("tmp", k), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r_shape = mtf.Shape([mtf.Dimension("length", 512), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r1 = mtf.get_variable(context.mesh, "R1", r1_shape, initializer=None, dtype=context.variable_dtype) r2 = mtf.get_variable(context.mesh, "R2", r2_shape, initializer=None, dtype=context.variable_dtype) r = mtf.einsum([r1, r2], r_shape) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") logits = r elif synthesize_mode == "dense_minus": # Dense Synthesizer Model tmp_dim = mtf.Dimension("memory_length", max_length) logits = mtf.layers.dense(mtf.relu(q), [tmp_dim], use_bias=False, name="pi", reduced_dims=[key_dim], variable_dtype=None) logits = mtf.slice(logits, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": pass else: 
length_dim = q.shape.get_dim_by_name("length") logits = mtf.slice(logits, 0, length_dim.size, "length") elif synthesize_mode == "random_plus_alpha" or \ synthesize_mode == "random_plus": # Mixture Random Synthesizer with learnable Alpha tf.logging.info("Using Random Plus Alpha") logits = mtf.einsum([q, k], reduced_dims=[key_dim]) num_heads = logits.shape.get_dim_by_name("heads") r_shape = mtf.Shape([mtf.Dimension("length", 512), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r = mtf.get_variable(context.mesh, "R", r_shape, initializer=None, dtype=context.variable_dtype) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, length_dim.name) if "alpha" in synthesize_mode: alpha = mtf.get_variable(context.mesh, "alpha", mtf.Shape([mtf.Dimension("alpha", 1)]), initializer=tf.zeros_initializer(), dtype=context.variable_dtype) alpha = mtf.sigmoid(alpha) logits = ((1-alpha) * logits) + (alpha * r) else: logits = logits + r elif synthesize_mode == "dense_plus_alpha" or \ synthesize_mode == "dense_plus": # Mixture Dense Synthesizer with learnable alpha tf.logging.info("Using Dense Plus Alpha Scaling") logits = mtf.einsum([q, k], reduced_dims=[key_dim]) tmp_dim = mtf.Dimension("memory_length", 512) r = mtf.layers.dense(mtf.relu(q), [tmp_dim], use_bias=False, name="pi", reduced_dims=[key_dim], variable_dtype=None) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": pass else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") if "alpha" in synthesize_mode: alpha = mtf.get_variable(context.mesh, "alpha", mtf.Shape([mtf.Dimension("alpha", 1)]), initializer=tf.zeros_initializer(), dtype=context.variable_dtype) alpha = mtf.sigmoid(alpha) logits = ((1-alpha) * logits) + (alpha * r) else: logits = logits + r if bias is not None: logits += bias weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit) weights = mtf.dropout( weights, context.train, 1.0 - dropout_rate, noise_shape=weights.shape - dropout_broadcast_dims) if synthesize and "plus" not in synthesize_mode: if synthesize_mode == "dense_minus": outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim]) else: outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim]) else: outputs_shape = q.shape - [key_dim] + value_dim outputs = mtf.einsum([weights, v], outputs_shape) return outputs
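# Toy NumPy sketch (not mtf) of the "random" synthesizer branch above: the
# attention logits come from a learned [max_length, heads, max_length] table R
# rather than from q.k, and are sliced down to the actual sequence lengths
# before the softmax over memory_length. Shapes and names are illustrative
# only.
import numpy as np

max_length, num_heads, length = 512, 8, 128
R = np.random.randn(max_length, num_heads, max_length).astype("float32")
logits = R[:length, :, :length]                  # slice "length" and "memory_length"
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)   # softmax over memory_length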
def nbody_prototype(mesh, infield=False, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin, shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation ## Compute initial initial conditions distributed input_field = tf.placeholder(dtype, [batch_size, nc, nc, nc]) if infield: initc = mtf.import_tf_tensor(mesh, input_field, shape=part_shape) else: initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( initc, stages[-1], kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = 
        mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size)

  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  final_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [final_field],
      output_dtype=dtype,
      output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
      name='my_dumb_reshape',
      splittable_dims=part_shape[:-1] + hr_shape[:4])

  return initc, final_field, input_field
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels): """Builds the UNet model graph, train op and eval metrics. Args: mesh: a MeshTensorflow.mesh object. mesh_impl: a mesh implementation, such as SimdMeshImpl and PlacementMeshImpl. dataset_str: a string of either train or eval. This is used for batch_norm. images: a laid out Tensor with shape [batch, x, y, num_channels] or [batch, x, y, z, num_channels]. labels: a laid out Tensor with shape [batch, x, y, num_classes] or [batch, x, y, z, num_classes]. Returns: Prediction and loss. """ is_training = (dataset_str == 'train') if dataset_str == 'train': batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train) else: assert dataset_str == 'eval' batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval) image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block) image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block) image_sx_dim = mtf.Dimension('image_sx_block', FLAGS.ct_resolution // FLAGS.image_nx_block) image_sy_dim = mtf.Dimension('image_sy_block', FLAGS.ct_resolution // FLAGS.image_ny_block) image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution) image_c_dim = mtf.Dimension('image_c', FLAGS.image_c) label_c_dim = mtf.Dimension('label_c', FLAGS.label_c) mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str) mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype) variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype) # Import input features. x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images), mtf_images_shape) x = mtf.cast(x, mtf_dtype) # Import ground truth labels. t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels), mtf_labels_shape) t = mtf.cast(t, mtf_dtype) # Transpose the blocks. if FLAGS.sampled_2d_slices: x = mtf.transpose(x, [ batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim, image_c_dim ]) t = mtf.transpose(t, [ batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim, label_c_dim ]) else: x = mtf.transpose(x, [ batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim, image_sz_dim, image_c_dim ]) t = mtf.transpose(t, [ batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim, image_sz_dim, label_c_dim ]) # Network. levels = [] all_bn_update_ops = [] # add levels with convolution or down-sampling for depth in range(FLAGS.network_depth): for n_conv in range(FLAGS.n_conv_per_block): if depth == 0 and n_conv == 0: # no dropout in 1st layer. 
dropout_keep_p = 1.0 else: dropout_keep_p = FLAGS.dropout_keep_p x, bn_update_ops = conv_with_spatial_partition( x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim, FLAGS.n_base_filters * (2**depth), dropout_keep_p, FLAGS.with_batch_norm, is_training, 'conv_{}_{}'.format(depth, n_conv), variable_dtype, 'conv_down_{}_{}'.format(depth, n_conv)) all_bn_update_ops.extend(bn_update_ops) levels.append(x) if depth < FLAGS.network_depth - 1: if FLAGS.sampled_2d_slices: x = mtf.layers.max_pool2d(x, ksize=(2, 2)) else: x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2)) # add levels with up-convolution or up-sampling for depth in range(FLAGS.network_depth - 1)[::-1]: x = deconv_with_spatial_partition( x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim, FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p, 'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1), variable_dtype, 'deconv_{}_0'.format(depth)) x = mtf.concat([x, levels[depth]], concat_dim_name='conv_{}_{}'.format( depth, FLAGS.n_conv_per_block - 1)) for n_conv in range(FLAGS.n_conv_per_block): x, bn_update_ops = conv_with_spatial_partition( x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim, FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p, FLAGS.with_batch_norm, is_training, 'conv_{}_{}'.format(depth, n_conv), variable_dtype, 'conv_up_{}_{}'.format(depth, n_conv)) all_bn_update_ops.extend(bn_update_ops) # no dropout in the final layer. if FLAGS.sampled_2d_slices: y = mtf.layers.conv2d_with_blocks( x, mtf.Dimension('label_c', FLAGS.label_c), filter_size=(1, 1), strides=(1, 1), padding='SAME', h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim, variable_dtype=variable_dtype, name='final_conv_{}'.format(FLAGS.label_c), ) else: y = mtf.layers.conv3d_with_blocks( x, mtf.Dimension('label_c', FLAGS.label_c), filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME', d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim, variable_dtype=variable_dtype, name='final_conv_{}'.format(FLAGS.label_c), ) # use mtf.constant to make sure there is no CPU-side constants. def scalar(v, dtype): return mtf.constant(mesh, v, shape=[], dtype=dtype) argmax_t = mtf.argmax(t, label_c_dim) liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype) lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype) argmax_y = mtf.argmax(y, label_c_dim) lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype) # summary of class ratios. lesion_pred_ratio = mtf.reduce_mean(lesion_y) lesion_label_ratio = mtf.reduce_mean(lesion_t) # summary of accuracy. accuracy = mtf.reduce_mean( mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype)) # Cross-entropy loss. Up-weight the liver region. 
pixel_loss = mtf.layers.softmax_cross_entropy_with_logits( y, t, label_c_dim) pixel_weight = scalar(1, mtf_dtype) + \ liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \ lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight, mtf_dtype) loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight) # Dice loss y_prob = mtf.softmax(y, reduced_dim=label_c_dim) lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'), reduced_dim=mtf.Dimension('label_c', 1)) prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t, output_shape=mtf.Shape([batch_dim])) prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t, output_shape=mtf.Shape([batch_dim])) loss_dice_per_case = mtf.reduce_mean( scalar(-2, mtf_dtype) * prob_intersect / (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype))) loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum( prob_intersect) / (mtf.reduce_sum(prob_area_sum) + scalar(FLAGS.dice_epsilon, mtf_dtype)) loss_dice = (loss_dice_per_case + loss_dice_global) * scalar( 0.5, mtf_dtype) loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar( 1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen intersect = mtf.reduce_sum(lesion_y * lesion_t, output_shape=mtf.Shape([batch_dim])) area_sum = mtf.reduce_sum(lesion_y + lesion_t, output_shape=mtf.Shape([batch_dim])) # summary of dice. dice_per_case = mtf.reduce_mean( scalar(2, mtf_dtype) * intersect / (area_sum + scalar(0.000001, mtf_dtype))) dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / ( mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype)) eval_metrics = { 'lesion_pred_ratio': lesion_pred_ratio, 'lesion_label_ratio': lesion_label_ratio, 'accuracy_of_all_classes': accuracy, 'lesion_dice_per_case': dice_per_case, 'lesion_dice_global': dice_global, 'loss_xen': loss_xen, 'loss_dice': loss_dice, 'loss_dice_per_case': loss_dice_per_case, 'loss_dice_global': loss_dice_global, } if FLAGS.sampled_2d_slices: y_prob_downsampled = mtf.layers.avg_pool2d( y_prob, ksize=(FLAGS.pred_downsample, ) * 2) if FLAGS.output_ground_truth: lesion_gt_downsampled = mtf.layers.avg_pool2d( mtf.slice(t, 2, 1, 'label_c'), ksize=(FLAGS.pred_downsample, ) * 2) else: y_prob_downsampled = mtf.layers.avg_pool3d( y_prob, ksize=(FLAGS.pred_downsample, ) * 3) if FLAGS.output_ground_truth: lesion_gt_downsampled = mtf.layers.avg_pool3d( mtf.slice(t, 2, 1, 'label_c'), ksize=(FLAGS.pred_downsample, ) * 3) liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c') lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c') preds = [ mtf.reduce_sum(liver_prob_downsampled, reduced_dim=mtf.Dimension('label_c', 1)), mtf.reduce_sum(lesion_prob_downsampled, reduced_dim=mtf.Dimension('label_c', 1)) ] if FLAGS.output_ground_truth: preds.append( mtf.reduce_sum(lesion_gt_downsampled, reduced_dim=mtf.Dimension('label_c', 1))) preds.extend([intersect, area_sum]) return preds, loss, eval_metrics, all_bn_update_ops
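# Toy sketch of the soft Dice term used above: per-case dice is
# 2*sum(p*t) / (sum(p) + sum(t) + eps), the loss is its negative averaged over
# the batch, and the "global" variant pools the sums over the whole batch
# before dividing. `eps` stands in for FLAGS.dice_epsilon; the data below is
# random and purely illustrative.
import numpy as np

eps = 1e-5
p = np.random.rand(2, 1000).astype("float32")           # predicted lesion probability per voxel
t = (np.random.rand(2, 1000) > 0.9).astype("float32")   # binary lesion ground truth
intersect = (p * t).sum(axis=1)
area_sum = p.sum(axis=1) + t.sum(axis=1)
loss_dice_per_case = (-2.0 * intersect / (area_sum + eps)).mean()
loss_dice_global = -2.0 * intersect.sum() / (area_sum.sum() + eps)
loss_dice = 0.5 * (loss_dice_per_case + loss_dice_global)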
def nbody_fn(mesh, klin, plin, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Pyramid N-body function """ stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition downsampling_factor = FLAGS.dsample lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor( mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor( mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor( mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False) kx_lr = mtf.import_tf_tensor( mesh, kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor( mesh, kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor( mesh, kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False) kx_hr = mtf.import_tf_tensor( mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor( mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor( mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim]) kv_hr = [ky_hr, kz_hr, kx_hr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Compute initial initial conditions distributed initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # Reshaping array into high resolution mesh field = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [initc], output_dtype=tf.float32, 
output_shape=hr_shape, name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise( lambda x: x[:, 0, 0, 0], [low], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) state = mtfpm.lpt_init( low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True, ) final_state = mtfpm.nbody( state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return initc, final_field
def benchmark_model(mesh): """ Initializes a 3D volume with random noise, and execute a forward FFT """ # Setup parameters bs = FLAGS.box_size nc = FLAGS.cube_size batch_size = FLAGS.batch_size a0 = FLAGS.a0 a = 1.0 nsteps = FLAGS.pm_steps # Compute a few things first, using simple tensorflow klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) stages = np.linspace(a0, a, nsteps, endpoint=True) # Initialize the integration steps stages = np.linspace(FLAGS.a0, 1.0, FLAGS.pm_steps, endpoint=True) # Generate a batch of 3D initial conditions initial_conditions = flowpm.linear_field( nc, # size of the cube bs, # Physical size of the cube ipklin, # Initial power spectrum batch_size=batch_size) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) from flowpm.kernels import laplace_kernel, gradient_kernel lap = tf.cast(laplace_kernel(kvec), tf.complex64) grad_x = gradient_kernel(kvec, 0) grad_y = gradient_kernel(kvec, 1) grad_z = gradient_kernel(kvec, 2) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = 8 n_block_y = 4 n_block_z = 1 halo_size = 4 # Parameters of the large scales decomposition downsampling_factor = 2 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) #state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, # part_shape[1:], downsampling_factor=downsampling_factor, antialias=True,) # Here we can run our nbody final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, 
  #                          kv_hr, halo_size,
  #                          downsampling_factor=downsampling_factor)

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)

  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                        final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim])
  # Hack using custom reshape because mesh is pretty dumb
  final_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [final_field],
      output_dtype=tf.float32,
      output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
      name='my_dumb_reshape',
      splittable_dims=part_shape[:-1] + hr_shape[:4])

  return mtf.reduce_sum(final_field)
def lpt_prototype(mesh, initial_conditions, derivs, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ stages = np.linspace(a0, a, nsteps, endpoint=True) lap, grad_x, grad_y, grad_z = derivs klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition downsampling_factor = FLAGS.dsample lnc = nc // 2**downsampling_factor # fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False) kx_hr = mtf.import_tf_tensor(mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor(mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor(mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim]) kv_hr = [kx_hr, ky_hr, kz_hr] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] initc = tf.reshape( initial_conditions, [1, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y, 1, nc]) initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6]) field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = 
mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) # Hack to handle reshape acrosss multiple dimensions #low = mtf.reshape(low, [batch_dim, x_dim, low.shape[2], low.shape[5], z_dim]) #low = mtf.reshape(low, lr_shape) state = mtfpm.lpt_init( low, high, a0, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, k_dims, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True, ) # Here we can run our nbody final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]) # Hack usisng custom reshape because mesh is pretty dumb final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return final_field
def __init__(self, config, is_training, input_ids, input_mask=None, token_type_ids=None, scope=None, mesh_shape="", layout=""): self.config = copy.deepcopy(config) del config if not is_training: self.config.layer_output_dropout_prob = 0.0 self.config.attention_probs_dropout_prob = 0.0 self.config.feedforward_intermediate_dropout_prob = 0.0 input_shape = input_ids.shape assert input_shape.ndims == 2 self._seq_dim = input_shape.dims[1] self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size) self._extra_losses = [] mesh = input_ids.mesh if token_type_ids is None: token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32) with tf.variable_scope(scope, default_name="bert"): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. self.embedding_table = mtf.get_variable( mesh, "word_embeddings", mtf.Shape([self.vocab_dim, self.model_dim]), initializer=self.embedding_initializer) self.word_embedding_output = mtf.gather( self.embedding_table, input_ids, self.vocab_dim) # Add positional embeddings and token type embeddings, then layer # normalize and perform dropout. self.embedding_output = self.word_embedding_output token_type_table = mtf.get_variable( mesh, "token_type_embeddings", mtf.Shape([self.token_type_vocab_dim, self.model_dim]), initializer=self.embedding_initializer) if token_type_ids is not None: self.embedding_output += mtf.gather( token_type_table, token_type_ids, self.token_type_vocab_dim) if self.config.position_signal == "embedding": full_position_table = mtf.get_variable( mesh, "position_embeddings", mtf.Shape([self.max_position_embeddings_dim, self.model_dim]), initializer=self.embedding_initializer) short_position_table = mtf.rename_dimension( mtf.slice(full_position_table, 0, self.seq_dim.size, self.max_position_embeddings_dim.name), self.max_position_embeddings_dim.name, self.seq_dim.name) self.embedding_output += short_position_table self.embedding_output = self.normalize(self.embedding_output) self.embedding_output = mtf.dropout( self.embedding_output, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) with tf.variable_scope("encoder"): attention_biases = [] if input_mask: # [batch_dim, memory_seq_dim] attention_biases.append( (1.0 - mtf.to_float(mtf.replace_dimensions( input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0) if self.config.position_signal == "relative_attention_bias": buckets_dim = mtf.Dimension("buckets", 32) rp_bucket = _relative_position_bucket( mtf.range(mesh, self.memory_seq_dim, tf.int32) - mtf.range(mesh, self.seq_dim, tf.int32), num_buckets=buckets_dim.size) bias_var = mtf.get_variable( mesh, "relative_attention_bias", [self.num_heads_dim, buckets_dim], initializer=tf.zeros_initializer()) attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim)) attention_bias = mtf.add_n(attention_biases) prev_layer_output = self.embedding_output self.all_encoder_layers = [] for block_num in range(self.config.num_blocks): with tf.variable_scope("block_%d" % block_num): for layer_idx, layer_type in enumerate(self.config.block_layers): layer_name = layer_type count = self.config.block_layers[:layer_idx].count(layer_type) if count: layer_name += "_%d" % count with tf.variable_scope(layer_name): x = prev_layer_output if self.config.residual_structure == "direct": x = self.normalize(x) if layer_type == "attention": x = self.self_attention(x, attention_bias) elif layer_type == "feedforward": x = self.feedforward(x) elif layer_type == "moe": x = self.moe(x, layout, mesh_shape, input_mask, 
is_training) else: raise ValueError("unknown layer type " + layer_type) x = mtf.dropout( x, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) layer_output = prev_layer_output + x if self.config.residual_structure == "original": layer_output = self.normalize(layer_output) prev_layer_output = layer_output self.all_encoder_layers.append(layer_output) self.sequence_output = prev_layer_output if self.config.residual_structure == "direct": self.sequence_output = self.normalize(self.sequence_output) # The "pooler" converts the encoded sequence tensor of shape # [batch_dim, seq_dim, hidden_size] to a tensor of shape # [batch_dim, hidden_size]. This is necessary for segment-level # (or segment-pair-level) classification tasks where we need a fixed # dimensional representation of the segment. with tf.variable_scope("pooler"): # We "pool" the model by simply taking the hidden state corresponding # to the first token. We assume that this has been pre-trained first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim) self.pooled_output = mtf.layers.dense( first_token_tensor, reduced_dims=[self.model_dim], new_dims=[self.model_dim], activation=mtf.tanh, kernel_initializer=self.dense_initializer, use_bias=self.config.use_bias)
def lpt_prototype(mesh, initial_conditions, derivs, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ stages = np.linspace(a0, a, nsteps, endpoint=True) lap, grad_x, grad_y, grad_z = derivs klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize # Parameters of the large scales decomposition downsampling_factor = 0 lnc = nc // 2**downsampling_factor # ffx_dim = mtf.Dimension("fnx", nc) ffy_dim = mtf.Dimension("fny", nc) ffz_dim = mtf.Dimension("fnz", nc) fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] field = mtf.import_tf_tensor(mesh, initial_conditions, shape=part_shape) state = mtfpm.lpt_init_single( field, a, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) print('TOTO', state) # Here we can run our nbody final_state = state # final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, # kv_lr, halo_size) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]) # Hack usisng custom reshape because mesh is pretty dumb final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) final_field = mtf.reshape(final_field, [batch_dim, ffx_dim, ffy_dim, ffz_dim]) return final_field
def transformer_moe_layer_v1(inputs, output_dim, hparams, train, variable_dtype, layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu, num_microbatches=None, token_embeddings=None, context=None): """Local heterogenous mixture of experts. See transformer_moe_layer_v1 in moe.py for a more detailed explanation for a generic moe layer. The heterogeneous mask outputted by generate_heterogeneous_expert_masks has dimension [maximum hidden size, maximum # layers, # experts] and its shape will overwrite the parameters moe_num_layers and moe_hidden_size in hparams. The layer-specific mask slice is applied at each expert layer to the activation which is [expert width, # experts]. If the heterogeneous_mask_info is None, there is no mask applied and the code is equivalent to the homogeneous case. The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting of the representations of all positions in a batch of sequences. Each position of each sequence is sent to 0-2 experts. The expert choices and the combination weights are determined by a learned gating function. This function returns a small auxiliary loss that should be added to the training loss of the model. This loss helps to balance expert usage. Without the loss, it is very likely that a few experts will be trained and the rest will starve. Dimensions cheat sheet: B: batch dim(s) L: original sequence length M: input depth N: output depth G: number of groups S: group size E: number of experts C: expert capacity Args: inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim] output_dim: a mtf.Dimension (for Transformer, this is input_dim) hparams: model hyperparameters train: a boolean variable_dtype: a mtf.VariableDType layout: optional - an input to mtf.convert_to_layout_rules mesh_shape: optional - an input to mtf.convert_to_shape nonpadding: an optional Tensor with shape [batch_dim(s), length_dim] and the same dtype as inputs, consisting of ones(nonpadding) and zeros(padding). activation: a function. num_microbatches: number of microbatches. token_embeddings: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]. These are the word embeddings for that correspond to the inputs. These can optionally be used to make routing decisions. context: a Context. Returns: outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim] loss: a mtf scalar Raises: ValueError: on unrecognized hparams.moe_gating """ orig_inputs = inputs experts_dim = mtf.Dimension("experts", hparams.moe_num_experts) if hparams.moe_heterogeneous_mask_info is not None: tf.logging.info("moe_heterogeneous_mask_info: {}".format( hparams.moe_heterogeneous_mask_info)) heterogeneous_mask = generate_heterogeneous_expert_masks( hparams.moe_heterogeneous_mask_info, hparams.moe_num_experts, experts_dim, mesh=inputs.mesh, expert_width=hparams.moe_hidden_size) # overwrite depth and width with the mask maximum dimension hparams.moe_num_layers = heterogeneous_mask.shape[1].size hparams.moe_hidden_size = heterogeneous_mask.shape[0].size hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size) # We "cheat" here and look at the mesh shape and layout. This is to ensure # that the number of groups is a multiple of the mesh dimension # over which those groups are split. 
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1]) # Hack: we assume that # "outer_batch" == replication of experts # mesh_dim_size can be derived from mesh_shape and orig_batch_dim # # We then reqire num_groups to be a multiple of mesh_dim_size. if orig_inputs.shape.dims[0].name == "outer_batch": outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2] else: outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1), orig_inputs.shape.dims[0]) # Number of MoE inputs (total number of position across batch_and_length_dims # per replica. n = 1 for d in batch_and_length_dims: n *= d.size n = n // outer_batch_dim.size mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, orig_batch_dim) num_groups, group_size = moe._split_into_groups( # pylint: disable=protected-access n, hparams.moe_group_size, mesh_dim_size) # TODO(barretzoph): implementation without pylint calls? group_size_dim = mtf.Dimension("group", group_size) num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups) moe_input_dims = [ outer_batch_dim, num_groups_dim, group_size_dim, input_dim ] # OGSM Tensor inputs = mtf.reshape(inputs, moe_input_dims) # Token embeddings that can be optionally used in the router for determining # where to send tokens. if hparams.moe_word_embed_mode is not None: token_embeddings = mtf.cast( mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype) # Each sequence sends expert_capacity positions to each expert. if train: capacity_factor = hparams.moe_capacity_factor_train else: capacity_factor = hparams.moe_capacity_factor_eval expert_capacity = min( group_size_dim.size, int((group_size_dim.size * capacity_factor) / experts_dim.size)) expert_capacity = max(expert_capacity, hparams.moe_min_expert_capacity) tf.logging.info("expert_capacity: %d" % expert_capacity) expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity) experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size) batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size) if nonpadding is not None: nonpadding = mtf.zeros(inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1]) if hparams.moe_gating == "top_2": # combine_tensor, # dispatch_tensor OG`SEC Tensors # (G is generally split along mesh dim) dispatch_tensor, combine_tensor, loss = moe._top_2_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, num_microbatches=num_microbatches, token_embeddings=token_embeddings) elif hparams.moe_gating == "top_n": dispatch_tensor, combine_tensor, loss = moe._top_n_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, num_microbatches=num_microbatches, token_embeddings=token_embeddings) elif hparams.moe_gating == "switch": dispatch_tensor, combine_tensor, loss = moe._switch_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, num_microbatches=num_microbatches, token_embeddings=token_embeddings) elif hparams.moe_gating 
== "ntlb": dispatch_tensor, combine_tensor, loss = moe._ntlb_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, num_microbatches=num_microbatches, token_embeddings=token_embeddings) elif hparams.moe_gating == "switch_max": dispatch_tensor, combine_tensor, loss = moe._switch_max_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, num_microbatches=num_microbatches, token_embeddings=token_embeddings) elif hparams.moe_gating == "expert_selection": dispatch_tensor, combine_tensor, loss = moe._expert_selection_gating( # pylint: disable=protected-access inputs=inputs, outer_expert_dims=None, experts_dim=experts_dim_unsplit, group_size_dim=group_size_dim, expert_capacity_dim=expert_capacity_dim, hparams=hparams, train=train, variable_dtype=variable_dtype, importance=nonpadding, name="expert_selection_gating", num_microbatches=num_microbatches, token_embeddings=token_embeddings) else: raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating) expert_inputs = mtf.einsum([inputs, dispatch_tensor], mtf.Shape([ outer_batch_dim, experts_dim_unsplit, num_groups_dim, expert_capacity_dim, input_dim ])) # Extra reshape reduces communication cost for model-parallel versions. # For model-parallel versions, this reshape causes an mtf.slice and for non- # model-parallel versions, this has no effect. d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size) expert_inputs = mtf.reshape( expert_inputs, mtf.Shape([ outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim, d_model_split_dim ])) # Split over batch -> split over experts expert_inputs = mtf.reshape( expert_inputs, mtf.Shape([ outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim ])) # Pretend we have heterogeneous_mask with shape [moe_num_layers, num_experts] for layer_idx in range(hparams.moe_num_layers): with tf.variable_scope("expert_layer_{}".format(layer_idx)): res_h = 0.0 if layer_idx > 0: res_h = expert_inputs expert_inputs = transformer.sublayer_rms_norm( expert_inputs, None, context) # Now feed the expert inputs through the experts. h = mtf.layers.dense_product( expert_inputs, reduced_dims=expert_inputs.shape.dims[-1:], new_dims=[hidden_dim], expert_dims=[experts_dim], activation_functions=activation, use_bias=False, variable_dtype=variable_dtype, name="wi") # apply dropout if hparams.moe_dropout_rate != 0.0: h = mtf.dropout(h, is_training=train, keep_prob=1.0 - hparams.moe_dropout_rate) # only if heterogeneous if hparams.moe_heterogeneous_mask_info is not None: # Get mask for current layer by slicing heterogeneous mask heterogeneous_mask_slice = mtf.slice(heterogeneous_mask, layer_idx, 1, "num_expert_layers") # Get rid of the expert layers dimension. 
heterogeneous_mask_slice = mtf.reshape( heterogeneous_mask_slice, [ heterogeneous_mask_slice.shape[0], heterogeneous_mask_slice.shape[-1] ]) h *= mtf.cast(heterogeneous_mask_slice, h.dtype) expert_output = mtf.layers.dense(h, output_dim, expert_dims=[experts_dim], use_bias=False, reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype, name="wo") if layer_idx < (hparams.moe_num_layers - 1): expert_output = transformer.sublayer_dropout( expert_output, None, context) expert_output += res_h expert_inputs = expert_output # Extra reshape reduces communication cost for model-parallel versions. # For model-parallel versions, this reshape causes an mtf.slice and for non- # model-parallel versions, this has no effect. expert_output = mtf.reshape( expert_output, mtf.Shape([ outer_batch_dim, experts_dim_unsplit, num_groups_dim, expert_capacity_dim, d_model_split_dim ])) # Split over experts -> split over batch expert_output = mtf.reshape( expert_output, mtf.Shape([ outer_batch_dim, experts_dim_unsplit, num_groups_dim, expert_capacity_dim, output_dim, ])) moe_output_dims = moe_input_dims[:-1] + [output_dim] output = mtf.einsum([expert_output, combine_tensor], mtf.Shape(moe_output_dims)) output = mtf.reshape(output, batch_and_length_dims + [output_dim]) return output, loss * hparams.moe_loss_coef
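# --- Illustrative sketch (not part of the MoE code above) ---------------------
# The expert-capacity arithmetic above in plain Python. The helper name
# `expert_capacity_from_group` is hypothetical; group_size, capacity_factor and
# min_expert_capacity play the same roles as group_size_dim.size,
# hparams.moe_capacity_factor_{train,eval} and hparams.moe_min_expert_capacity.

def expert_capacity_from_group(group_size, num_experts, capacity_factor,
                               min_expert_capacity):
  """Number of token slots each expert reserves per group."""
  capacity = min(group_size, int(group_size * capacity_factor / num_experts))
  return max(capacity, min_expert_capacity)

# Example: groups of 1024 positions routed to 64 experts with a 1.25x
# train-time capacity factor.
print(expert_capacity_from_group(1024, 64, 1.25, min_expert_capacity=4))  # 20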
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print("Dtype : ", dtype, npdtype) # Compute a few things first, using simple tensorflow kny = 1 * np.pi * nc / bs R1, R2 = 3., 3 * 1.2 stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition scalar = mtf.Dimension("scalar", 1) fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) #k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # # Begin simulation ## Compute initial initial conditions distributed #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) fieldvar = mtf.get_variable(mesh, 'linear', part_shape) input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape) linearop = mtf.assign(fieldvar, mtfinp) #field = fieldvar initc = fieldvar print("initc : ", initc) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( 
fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( initc, stages[-1], kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) ## x = final_field ppars, mpars, kernel = setupfnn() pwts, pbias, pmx, psx = ppars mwts, mbias, mmx, msx, mmy, msy = mpars msy, mmy = msy[0], mmy[0] print("mmy : ", mmy) size = 3 k_dims = [d.shape[0] for d in kv] k_dims = [k_dims[2], k_dims[0], k_dims[1]] tfnc, tfbs = float_to_mtf(nc * 1., mesh, scalar), float_to_mtf(bs, mesh, scalar) x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs], output_dtype=cdtype) x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype) x1d = mtf.add(x1d, -1.) x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype) x2f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype) x12 = x1 - x2 width = tf.placeholder(tf.float32, shape=()) def apply_pwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID') yy = tf.concat([y, y1, y2], axis=-1) yy = yy - pmx yy = yy / psx yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0]) yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1]) yy3 = tf.matmul(yy2, pwts[2]) + pbias[2] pmodel = tf.nn.sigmoid(width * yy3) return pmodel[..., 0] pmodel = mtf.slicewise( apply_pwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_pwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) def apply_mwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) zz = tf.concat([ tf.expand_dims(x, -1), tf.expand_dims(x1, -1), tf.expand_dims(x2, -1) ], axis=-1) zz = zz - mmx zz = zz / msx zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0]) zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1]) zz3 = tf.matmul(zz2, mwts[2]) + mbias[2] mmodel = zz3 * msy + mmy return mmodel[..., 0] mmodel = mtf.slicewise( 
apply_mwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_mwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) model = pmodel * mmodel mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior #k_dims = [d.shape[0] for d in kv] #k_dims = [k_dims[2], k_dims[0], k_dims[1]] k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3 # Total loss #diff = (model - mtfdata) modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype) modelsmf = mtf.cwise(cwise_fingauss, [modelf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype) #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype) #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype) ##Anneal R0 = tf.placeholder(tf.float32, shape=()) M0 = tf.placeholder(tf.float32, shape=()) off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder( tf.float32, shape=data.shape) mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape) mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape) diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0) #diff = diff / 0.25 #diff = (diff + mtfoff)*mtfistd #For some reason, doing things wrong this one diff = (diff + mtfoff) / 0.25 def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype) var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
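# --- Illustrative sketch (not part of recon_prototype above) ------------------
# NumPy stand-in for the `_cwise_smooth` annealing filter: the residual is
# taken to Fourier space, multiplied by a Gaussian kernel of width R0, and
# transformed back. np.fft replaces the distributed r2c3d/c2r3d transforms,
# and the 2*pi*fftfreq wavenumber convention is an assumption about
# flowpm.kernels.fftk.
import numpy as np


def smooth_residual(diff, nc, bs, R0):
  k = 2 * np.pi * np.fft.fftfreq(nc)
  kx, ky, kz = np.meshgrid(k, k, k, indexing='ij')
  kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
  wts = np.exp(-kk * (R0 * bs / nc)**2)
  return np.fft.ifftn(np.fft.fftn(diff) * wts).real


diff = np.random.randn(32, 32, 32)
print(smooth_residual(diff, nc=32, bs=100., R0=2.).shape)  # (32, 32, 32)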
def _call_internal(self, context, inputs, targets=None, attributes=None, z=None): """Compute logits based on inputs (all positions in parallel). Also updates context if applicable. Args: context: a Context inputs: a Tensor targets: an optional Tensor attributes: an optional Tensor Returns: logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim] """ mesh = inputs.mesh if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims: # Training an ensemble where all models are trained on the same examples. inputs = mtf.broadcast(inputs, [self.ensemble_dim] + inputs.shape.dims) if self.ensemble_dim not in attributes.shape.dims: attributes = mtf.broadcast(attributes, [self.ensemble_dim] + attributes.shape.dims) if targets: targets = mtf.broadcast(targets, [self.ensemble_dim] + targets.shape.dims) if "embedding" in context.shared_params: vocab_embedding = context.shared_params["embedding"] else: vocab_embedding = VocabEmbedding(mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype, name="embedding", ensemble_dim=self.ensemble_dim) x = vocab_embedding.ids_to_embedding(inputs) if self.positional_embedding: if "positional_embedding" in context.shared_params: pos_emb_var = context.shared_params["positional_embedding"] else: pos_emb_var = mtf.layers.embedding_weights( mesh, self.max_length_dim, self.model_dim, context.variable_dtype, "positional_embedding", ensemble_dim=self.ensemble_dim) if (context.length_dim is not None and context.length_dim.size > self.max_length_dim.size): message = ( "Length dimension exceeds size of positional embedding table. " "length_dim.size > max_length_dim.size %s vs %s." % (context.length_dim, self.max_length_dim)) if context.position_is_default: # Definitely getting overflow in this case. raise ValueError(message) else: tf.logging.warning( message + " This may be OK if there are several shorter sequences packed " "together. Otherwise, the later positions will get zeros."
) if context.position_is_default: pos_emb = mtf.rename_dimension( mtf.slice(pos_emb_var, 0, context.length_dim.size, self.max_length_dim.name), self.max_length_dim.name, context.length_dim.name) else: pos_emb = mtf.gather(pos_emb_var, context.position, self.max_length_dim, output_shape=x.shape) x += pos_emb if self.attribute_embedding: if "attribute_embedding" in context.shared_params: att_emb_var = context.shared_params["attribute_embedding"] else: att_emb_var = mtf.layers.embedding_weights( mesh, self.attribute_dim, self.model_dim, context.variable_dtype, "attribute_embedding", ensemble_dim=self.ensemble_dim) att_emb = mtf.gather(att_emb_var, attributes, self.attribute_dim, output_shape=x.shape) # Addition of x and attribute # x *= LAMBDA_ATTRIBUTE * sty_emb # # Concatenation of x and attribute x_attribute = mtf.concat([x, att_emb], self.model_dim.name) x = mtf.layers.dense(x_attribute, self.model_dim, activation=None, variable_dtype=context.variable_dtype, name="comb_x_attribute") if z: z = mtf.layers.dense(z, self.model_dim, activation=None, variable_dtype=context.variable_dtype, name="z") # raise ValueError("x shape=%s , z shape=%s" % (x.shape, z.shape)) x += z x = self.layer_stack.call(context, x) if self.output_vocab_dim is None: return x if self.shared_embedding_and_softmax_weights: logits = vocab_embedding.hidden_to_logits(x) else: logits = mtf.layers.dense(x, self.output_vocab_dim, use_bias=False, variable_dtype=context.variable_dtype, reduced_dims=x.shape.dims[-1:], name="logits") if targets is not None and context.losses is not None: context.losses.append( self._compute_loss(context, logits, targets, self.output_vocab_dim)) if self.ensemble_dim: logits = reduce_ensemble_logits(logits, self.ensemble_dim, self.output_vocab_dim) return logits
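# --- Illustrative sketch (not part of _call_internal above) -------------------
# The two positional-embedding paths in NumPy: with default positions 0..L-1,
# slicing the first L rows of the table is equivalent to a gather; with packed
# sequences the explicit per-token positions are gathered instead.
import numpy as np

max_length, d_model, length = 512, 8, 6
pos_emb_var = np.random.randn(max_length, d_model)

# context.position_is_default: slice the table.
pos_emb_sliced = pos_emb_var[:length]

# Packed example: two sequences of lengths 4 and 2 reuse early positions.
positions = np.array([0, 1, 2, 3, 0, 1])
pos_emb_gathered = pos_emb_var[positions]

print(np.allclose(pos_emb_sliced, pos_emb_var[np.arange(length)]))  # True
print(pos_emb_gathered.shape)  # (6, 8)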
def recon_model(mesh, data, R0, x0, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print(dtype, npdtype) # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation if x0 is None: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.random_normal_initializer( mean=0.0, stddev=1, seed=None)) else: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.constant_initializer(x0)) print("\nfieldvar : \n", fieldvar) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( fieldvar, stages[-1], kv_lr, 
halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 #*nc**3 # Total loss diff = (final_field - mtfdata) R0 = tf.constant(R0) print("R0 in the recon_model : ", R0) def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts # Element-wise function that applies a Fourier kernel plambda = FLAGS.plambda def _cwise_logprob(finalfield, data): galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield)) logprob = galmean.log_prob(data) return -1 * logprob cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype) cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype) final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype) chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata], output_dtype=tf.float32) # chisq = mtf.reduce_sum(chisq) ## # loss = chisq + prior def _cwise_sample(finalfield, data): galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield)) sample = galmean.sample() return sample sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata], output_dtype=tf.float32) # fields = [fieldvar, sample] metrics = [chisq, prior, loss] return fields, metrics, kv
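# --- Illustrative sketch (not part of recon_model above) ----------------------
# The Poisson data term `_cwise_logprob` in NumPy/SciPy: observed counts are
# modelled as Poisson with rate plambda * (1 + delta_model) and chisq sums the
# negative log-likelihood. scipy.stats.poisson stands in for
# tfp.distributions.Poisson.
import numpy as np
from scipy.stats import poisson


def poisson_chisq(final_field, data, plambda):
  rate = plambda * (1.0 + final_field)
  return -np.sum(poisson.logpmf(data, rate))


field = np.random.uniform(-0.5, 0.5, size=(16, 16, 16))
data = np.random.poisson(0.1 * (1.0 + field))
print(poisson_chisq(field, data, plambda=0.1))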
def force(state, lr_shape, hr_shape, kvec_lr, kvec_hr, halo_size, cosmology=Planck15, downsampling_factor=2, pm_nc_factor=1, antialias=True, **kwargs): """ Estimate force on the particles given a state. Parameters: ----------- state: tensor Input state tensor of shape (3, batch_size, npart, 3) boxsize: float Size of the simulation volume (Mpc/h) TODO: check units cosmology: astropy.cosmology Cosmology object pm_nc_factor: int TODO: @modichirag please add doc """ X, P, F = state #TODO: support different factor assert pm_nc_factor == 1 lnc = lr_shape[-1].size part_shape = X.shape k_dims_lr = [d.shape[0] for d in kvec_lr] k_dims_hr = [d.shape[0] for d in kvec_hr] # Reorder the FFTs which where transposed# y,z,x k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]] k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]] # Paint the particles on the high resolution mesh field = mtf.zeros(X.mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) field = mesh_utils.cic_paint(field, X, halo_size) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size) # Split the field into low and high resolution field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) hr_field = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr) hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr) kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield, kvec_lr, r_split=0) kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr) kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield, kvec_hr, r_split=0) kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr) # Reorder the low res FFTs which where transposed# y,z,x kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]] kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]] displacement = [] for f, g in zip(kfield_lr, kfield_hr): f = mesh_utils.c2r3d(f, lr_shape[-3:]) f = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [f], output_dtype=tf.float32, output_shape=mtf.Shape(hr_shape[0:4] + [ mtf.Dimension('sx_block', lnc // hr_shape[1].size), mtf.Dimension('sy_block', lnc // hr_shape[2].size), mtf.Dimension('sz_block', lnc // hr_shape[3].size) ]), name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: f = mtf.pad(f, [ halo_size // 2**downsampling_factor, halo_size // 2**downsampling_factor ], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]): f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size // 2**downsampling_factor) f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)]) f = mesh_utils.upsample(f, downsampling_factor) f = mtf.reshape(f, f.shape[:-1]) g = mesh_utils.c2r3d(g, f.shape[-3:]) high_shape = g.shape # And now we remove the large scales g = 
mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)]) _low = mesh_utils.downsample(g, downsampling_factor, antialias=antialias) g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor), g.shape) g = mtf.reshape(g, high_shape) d = mesh_utils.cic_readout(f + g, X, halo_size) displacement.append(d) # Readout the force to particle positions F = mtf.stack([d for d in displacement], "ndim", axis=4) F = F * 1.5 * cosmology.Om0 return X, P, F
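# --- Illustrative sketch (not part of force() above) --------------------------
# 1-D NumPy picture of the scale separation used above: average pooling and
# nearest-neighbour repetition stand in for mesh_utils.downsample/upsample
# (which presumably include antialiasing), and subtracting the reconstructed
# low-pass field leaves the short-range component that is added to the
# long-range force before the CIC readout.
import numpy as np


def downsample(x, factor):
  return x.reshape(-1, factor).mean(axis=1)


def upsample(x, factor):
  return np.repeat(x, factor)


g = np.sin(np.linspace(0., 4. * np.pi, 64)) + 0.1 * np.random.randn(64)
factor = 2**2  # downsampling_factor = 2
low = upsample(downsample(g, factor), factor)
short_range = g - low
print(np.allclose(g, low + short_range))  # True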
def gradient_based_subword_tokenization(x, length_dim, max_subword_length=4, downsample=None, use_offsets=False, consider_chars_as_blocks=False, use_block_pos_embedding=False, share_block_kernel=False, memory_embeddings=0, context=None, block_mixing_mode=None, activation="softmax", downsample_function="mean"): """Implements GBSWT from Charformer. Args: x: a Tensor containing length_dim length_dim: a Dimension max_subword_length: integer downsample: integer. use_offsets: boolean. consider_chars_as_blocks: boolean. use_block_pos_embedding: boolean. share_block_kernel: boolean. memory_embeddings: integer. context: Context. block_mixing_mode: Str for block mixing. activation: Str for block ranking. downsample_function: Str, supports mean/linformer for now. Returns: a Tensor with the same shape as x. Raises: ValueError: if channels or depth don't match. """ # don't use this for now. del max_subword_length del memory_embeddings all_blocks = [] all_scores = [] tf.logging.info("GSW block layer") def _tile(x, n, tile_dim): # Simple tile function in MTF. return mtf.concat([x] * n, tile_dim.name) def _repeat(x, n, repeat_dim): # repeat function in MTF tmp_dim = mtf.Dimension("tmp", 1) expand_shape = mtf.Shape(x.shape.dims + [tmp_dim]) x = mtf.reshape(x, expand_shape) x = _tile(x, n, tmp_dim) output_shape = [] for dim in x.shape.dims: if dim.name == "tmp": continue if dim.name == repeat_dim.name: dim = mtf.Dimension(dim.name, dim.size * n) output_shape.append(dim) output_shape = mtf.Shape(output_shape) x = mtf.reshape(x, output_shape) return x def _combined_dim(dims): return mtf.Dimension(dims[0].name, mtf.Shape(dims).size) # compute all subword blocks # TODO(yitay): handle offsets to get all blocks if activation == "sigtanh": # one score for sigmoid tmp_dim = mtf.Dimension("block_score", 2) else: tmp_dim = mtf.Dimension("block_score", 1) model_dim = x.shape[-1] subword_blocks_width = [2, 3, 4] if consider_chars_as_blocks: subword_blocks_width += [1] if share_block_kernel: block_kernel_shape = mtf.Shape([model_dim, tmp_dim]) block_kernel = mtf.get_variable(x.mesh, "block_kernel", block_kernel_shape, initializer=None, dtype=context.variable_dtype) else: block_kernel = None for subword_len in subword_blocks_width: if use_block_pos_embedding: # this is turn off by default. It is meant to support cases like # parameterized pooling or other features. block_len_dim = mtf.Dimension(length_dim.name, subword_len) # TODO(vqtran): Consider other positional embeddings. 
block_pos_emb = sinusoid_positional_embedding_weights( context.mesh, block_len_dim, x.shape[-1], context.variable_dtype.activation_dtype) block_pos_emb = _repeat( block_pos_emb, math.ceil(length_dim.size / float(subword_len)), block_len_dim) if use_offsets: offset_space = subword_len else: offset_space = 1 for offsets in range(offset_space): if offsets > 0: xoff = mtf.shift(x, offsets, length_dim, wrap=False) if use_block_pos_embedding: block_pos_emb = mtf.shift(block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False) else: xoff = x tf.logging.info("SW len=%d offset=%d", subword_len, offsets) if length_dim.size % subword_len != 0: tf.logging.info("Not divisible by length") # add extra padding tokens pad_amt = int(subword_len) - int(length_dim.size % subword_len) kp = mtf.pad(xoff, [0, pad_amt], length_dim.name) else: kp = xoff if use_block_pos_embedding: kp += block_pos_emb bx = mtf.pool_tensor_1d( kp, pool_dim=kp.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(subword_len)) block_score = mtf.layers.dense(bx, [tmp_dim], use_bias=False, name="bx", reduced_dims=[model_dim], variable_dtype=None, kernel_weights=block_kernel) expand_bx = _repeat(bx, subword_len, length_dim) expand_scores = _repeat(block_score, subword_len, length_dim) if offsets > 0: # add offset. expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name) expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name) new_len = expand_bx.shape.get_dim_by_name(length_dim.name) if new_len.size < length_dim.size: pad_amt = new_len.size - length_dim.size expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name) expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name) elif new_len.size > length_dim.size: expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name) expand_scores = mtf.slice(expand_scores, 0, length_dim.size, length_dim.name) new_tmp_dim = mtf.Dimension("extra_dim", 1) expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim]) expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim]) expand_bx = mtf.reshape(expand_bx, expand_shape) expand_scores = mtf.reshape(expand_scores, expand_scores_shape) all_blocks.append(expand_bx) all_scores.append(expand_scores) all_blocks = mtf.concat(all_blocks, new_tmp_dim.name) all_scores = mtf.concat(all_scores, new_tmp_dim.name) tf.logging.info(all_blocks) new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim") combined_dim = _combined_dim([new_tmp_dim, tmp_dim]) block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim block_net = mtf.reshape(all_scores, block_net_shape) if block_mixing_mode == "score_attention": tf.logging.info("Using score attention") att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim]) tf.logging.info(block_net) att = mtf.softmax(att, reduced_dim=att.shape[-1]) block_net = mtf.einsum([att, block_net], output_shape=block_net.shape) tf.logging.info(block_net) if activation == "softmax": block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim) elif activation == "tanh": tf.logging.info("Using tanh") block_net = mtf.tanh(block_net) all_blocks = block_net * all_blocks all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim) output = all_blocks if downsample: output_length = output.shape.get_dim_by_name("length") if output_length.size % int(downsample) != 0: pad_amt = int(downsample) - int( output_length.size % int(downsample)) output = mtf.pad(output, [0, pad_amt], output_length.name) if downsample_function == "mean": output = 
mtf.pool_tensor_1d( output, pool_dim=output.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(downsample)) else: raise ValueError("Downsampling function not implemented.") return output
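# --- Illustrative sketch (not part of the GBSWT code above) -------------------
# The core block-scoring computation in NumPy, with offsets, padding-token
# handling and the optional block-mixing/attention paths omitted: mean-pool
# candidate blocks of width 2/3/4, score each pooled block with a shared
# linear scorer, broadcast back to character resolution, softmax the scores
# across block sizes per position, and take the weighted sum.
import numpy as np


def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)


length, d_model = 12, 8
x = np.random.randn(length, d_model)
block_kernel = np.random.randn(d_model, 1)  # shared block scorer

blocks, scores = [], []
for w in (2, 3, 4):
  pad = (-length) % w
  xp = np.pad(x, ((0, pad), (0, 0)))
  pooled = xp.reshape(-1, w, d_model).mean(axis=1)   # [ceil(length/w), d_model]
  score = pooled @ block_kernel                      # [ceil(length/w), 1]
  blocks.append(np.repeat(pooled, w, axis=0)[:length])
  scores.append(np.repeat(score, w, axis=0)[:length])

blocks = np.stack(blocks, axis=1)                    # [length, 3, d_model]
scores = softmax(np.stack(scores, axis=1), axis=1)   # [length, 3, 1]
latent = (scores * blocks).sum(axis=1)               # [length, d_model]
print(latent.shape)  # (12, 8)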
def force_single(state, lr_shape, hr_shape, kvec_lr, halo_size, cosmology=Planck15, pm_nc_factor=1, **kwargs): """ Estimate force on the particles given a state. Parameters: ----------- state: tensor Input state tensor of shape (3, batch_size, npart, 3) boxsize: float Size of the simulation volume (Mpc/h) TODO: check units cosmology: astropy.cosmology Cosmology object pm_nc_factor: int TODO: @modichirag please add doc """ X, P, F = state #TODO: support different factor assert pm_nc_factor == 1 lnc = lr_shape[-1].size part_shape = X.shape # Paint the particles on the high resolution mesh field = mtf.zeros(X.mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) field = mesh_utils.cic_paint(field, X, halo_size) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: field = mtf.slice(field, halo_size, block_size_dim.size, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field], output_dtype=tf.float32, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) k_dims_lr = [d.shape[0] for d in kvec_lr] k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]] lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr) kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr) # Reorder the low res FFTs which where transposed# y,z,x kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]] displacement = [] for f in kfield_lr: f = mesh_utils.c2r3d(f, lr_shape[-3:]) f = mtf.slicewise( lambda x: tf.expand_dims( tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [f], output_dtype=tf.float32, output_shape=mtf.Shape(hr_shape[0:4] + [ mtf.Dimension('sx_block', lnc // hr_shape[1].size), mtf.Dimension('sy_block', lnc // hr_shape[2].size), mtf.Dimension('sz_block', lnc // hr_shape[3].size) ]), name='my_reshape', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) for block_size_dim in hr_shape[-3:]: f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]): f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size) d = mesh_utils.cic_readout(f, X, halo_size) displacement.append(d) # Readout the force to particle positions F = mtf.stack([d for d in displacement], "ndim", axis=4) F = F * 1.5 * cosmology.Om0 return X, P, F
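# --- Illustrative sketch (not part of force_single above) ---------------------
# Plain-NumPy guess at what the Fourier-space force solve amounts to, assuming
# mesh_kernels.apply_gradient_laplace_kernel implements the usual particle-mesh
# convention (gradient of the inverse Laplacian, F_i ~ IFFT(i k_i / k^2 *
# delta_k)). Normalisation and sign conventions differ between codes, so treat
# this as an illustration only.
import numpy as np


def pm_force(delta, boxsize):
  nc = delta.shape[0]
  k = 2 * np.pi * np.fft.fftfreq(nc, d=boxsize / nc)
  kx, ky, kz = np.meshgrid(k, k, k, indexing='ij')
  k2 = kx**2 + ky**2 + kz**2
  k2[0, 0, 0] = 1.  # avoid division by zero for the k=0 mode
  delta_k = np.fft.fftn(delta)
  return [np.fft.ifftn(1j * ki / k2 * delta_k).real for ki in (kx, ky, kz)]


fx, fy, fz = pm_force(np.random.randn(32, 32, 32), boxsize=100.)
print(fx.shape)  # (32, 32, 32)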
def lpt_prototype(mesh, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition downsampling_factor = 0 lnc = nc // 2**downsampling_factor # fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) # # Reshaping array into high resolution mesh # field = mtf.slicewise(lambda x:tf.expand_dims(tf.expand_dims(tf.expand_dims(x, axis=1),axis=1),axis=1), # [initc], # output_dtype=tf.float32, # output_shape=hr_shape, # name='my_reshape', # splittable_dims=lr_shape[:-1]+hr_shape[1:4]+part_shape[1:3]) # state = mtfpm.lpt_init_single( initc, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = state #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], 
block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) #final_field = mtf.reshape(final_field, [batch_dim, fx_dim, fy_dim, fz_dim]) # Hack using custom reshape because mesh is pretty dumb final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) return initc, final_field
def _call_internal(self, context, inputs, targets=None): """Compute logits based on inputs (all positions in parallel). Also updates context if applicable. Args: context: a Context inputs: a Tensor targets: an optional Tensor Returns: logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim] """ mesh = inputs.mesh if "embedding" in context.shared_params: embedding_weights = context.shared_params["embedding"] else: embedding_weights = mtf.layers.embedding_weights( mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype, name="embedding") x = mtf.gather(embedding_weights, inputs, self.input_vocab_dim) if "positional_embedding" in context.shared_params: pos_emb_var = context.shared_params["positional_embedding"] else: pos_emb_var = mtf.layers.embedding_weights( mesh, self.max_length_dim, self.model_dim, context.variable_dtype, "positional_embedding") if context.position_is_default: pos_emb = mtf.rename_dimension( mtf.slice(pos_emb_var, 0, context.length_dim.size, self.max_length_dim.name), self.max_length_dim.name, context.length_dim.name) else: pos_emb = mtf.gather( pos_emb_var, context.position, self.max_length_dim, output_shape=x.shape) x += pos_emb x = self.layer_stack.call(context, x) if self.output_vocab_dim is None: return x if self.shared_embedding_and_softmax_weights: logits = mtf.einsum( [x * (self.model_dim.size ** -0.5), embedding_weights], reduced_dims=[self.model_dim]) else: logits = mtf.layers.dense( x, self.output_vocab_dim, use_bias=False, variable_dtype=context.variable_dtype, name="logits") if targets is not None and context.losses is not None: off_value = self.label_smoothing / self.output_vocab_dim.size on_value = 1.0 - self.label_smoothing + off_value soft_targets = mtf.one_hot( targets, self.output_vocab_dim, dtype=context.activation_dtype, on_value=on_value, off_value=off_value) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, self.output_vocab_dim, z_loss=self.z_loss if context.train else 0.0) weights = mtf.layers.weights_nonzero( targets, dtype=context.activation_dtype) loss = mtf.reduce_mean(loss * weights) context.losses.append(loss) return logits
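# --- Illustrative sketch (not part of _call_internal above) -------------------
# The label-smoothed loss at the end of this function in NumPy: smoothed
# one-hot targets with the same on/off values, softmax cross-entropy, and a
# nonzero-target weight mask (the z_loss term is omitted).
import numpy as np


def smoothed_xent(logits, targets, vocab_size, label_smoothing=0.1):
  off_value = label_smoothing / vocab_size
  on_value = 1.0 - label_smoothing + off_value
  soft_targets = np.full(logits.shape, off_value)
  soft_targets[np.arange(len(targets)), targets] = on_value
  z = logits - logits.max(axis=-1, keepdims=True)
  log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
  loss = -(soft_targets * log_probs).sum(axis=-1)
  weights = (targets != 0).astype(loss.dtype)  # padding has target id 0
  return (loss * weights).mean()


logits = np.random.randn(6, 32)
targets = np.array([5, 9, 3, 0, 0, 0])
print(smoothed_xent(logits, targets, vocab_size=32))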
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print(dtype, npdtype) # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition downsampling_factor = 2 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype(npdtype), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype(npdtype), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype(npdtype), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False, dtype=npdtype) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False, dtype=npdtype) kx_hr = mtf.import_tf_tensor(mesh, 
kvec_hr[0].squeeze().astype(npdtype), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor(mesh, kvec_hr[1].squeeze().astype(npdtype), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor(mesh, kvec_hr[2].squeeze().astype(npdtype), shape=[padded_sz_dim]) kv_hr = [ky_hr, kz_hr, kx_hr] # kvec for prior blocks prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x) prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y) prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z) kvec_pr = flowpm.kernels.fftk( [nc // n_block_x, nc // n_block_y, nc // n_block_z], symmetric=False, dtype=npdtype) kx_pr = mtf.import_tf_tensor(mesh, kvec_pr[0].squeeze().astype(npdtype), shape=[prior_sx_dim]) ky_pr = mtf.import_tf_tensor(mesh, kvec_pr[1].squeeze().astype(npdtype), shape=[prior_sy_dim]) kz_pr = mtf.import_tf_tensor(mesh, kvec_pr[2].squeeze().astype(npdtype), shape=[prior_sz_dim]) kv_pr = [ky_pr, kz_pr, kx_pr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] ## Compute initial initial conditions distributed fieldvar = mtf.get_variable(mesh, 'linear', hr_shape) input_field = tf.placeholder(data.dtype, [ batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x, nc // n_block_y, nc // n_block_z ]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape) linearop = mtf.assign(fieldvar, mtfinp) # field = fieldvar initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) # for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=initc.dtype, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) final_state = mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) else: final_state = mtfpm.lpt_init(low, high, stages[-1], kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for 
block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv_pr] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 # Total loss diff = (final_field - mtfdata) R0 = tf.placeholder(tf.float32, shape=()) def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype) var_grads = [ mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype) ] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
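# --- Illustrative sketch (not part of recon_prototype above) ------------------
# The annealed update step at the end of this function in NumPy: the gradient
# of the loss w.r.t. the linear field is high-pass filtered in Fourier space
# with the `1 - exp(...)` kernel from `_cwise_highpass`, then applied as a
# plain gradient-descent assign. np.fft replaces the distributed transforms.
import numpy as np


def highpass_update(field, grad, nc, bs, R0, lr):
  nyq = np.pi * nc / bs
  k = 2 * np.pi * np.fft.fftfreq(nc)
  kx, ky, kz = np.meshgrid(k, k, k, indexing='ij')
  kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
  wts = np.exp(-kk * (R0 * bs / nc + 1 / nyq)**2)
  grad_hp = np.fft.ifftn(np.fft.fftn(grad) * (1 - wts)).real
  return field - lr * grad_hp


field = highpass_update(np.random.randn(32, 32, 32), np.random.randn(32, 32, 32),
                        nc=32, bs=100., R0=2., lr=0.01)
print(field.shape)  # (32, 32, 32)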