def computation_fn(): graph = mtf.Graph() mesh = mtf.Mesh(graph, 'my_mesh') mesh_shape = mtf.convert_to_shape('all:2') layout = 'none:all' mesh_devices = [''] * mesh_shape.size mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices, device_assignment) hidden_dim = mtf.Dimension('hidden', 3) w = mtf.get_variable(mesh, 'w', shape=[hidden_dim], initializer=tf.constant_initializer( [0.1, -0.2, -0.1])) x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32) loss = mtf.reduce_mean(mtf.square(x - w)) lr, update_ops = optimization_lib.create_optimizer( loss, 0.2, 100, 10) self.lowering = mtf.Lowering(graph, {mesh: mesh_impl}) tf_update_ops = [ self.lowering.lowered_operation(op) for op in update_ops ] tf_update_ops.append( tf.assign_add(tf.train.get_or_create_global_step(), 1)) train_op = tf.group(tf_update_ops) return lr, train_op
def toy_model(features, mesh): """A toy model implemented by mesh tensorlfow.""" batch_dim = mtf.Dimension('batch', FLAGS.batch_size) io_dim = mtf.Dimension('io', FLAGS.io_size) master_dtype = tf.as_dtype(FLAGS.master_dtype) slice_dtype = tf.as_dtype(FLAGS.slice_dtype) activation_dtype = tf.as_dtype(FLAGS.activation_dtype) x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim])) x = mtf.cast(x, activation_dtype) h = x for lnum in xrange(1, FLAGS.num_hidden_layers + 2): if lnum + 1 == FLAGS.num_hidden_layers + 2: # output layer dim = io_dim elif lnum % 2 == 0: dim = mtf.Dimension('hidden_even', FLAGS.hidden_size) else: dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size) h = mtf.layers.dense( h, dim, use_bias=False, master_dtype=master_dtype, slice_dtype=slice_dtype, name='layer_%d' % lnum) y = h loss = mtf.reduce_mean(mtf.square(y - x)) return y, loss
def computation_fn(): graph = mtf.Graph() mesh = mtf.Mesh(graph, 'my_mesh') mesh_shape = mtf.convert_to_shape('all:2') layout = 'none:all' mesh_devices = [''] * mesh_shape.size mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( mesh_shape, mtf.convert_to_layout_rules(layout), mesh_devices, device_assignment) hidden_dim = mtf.Dimension('hidden', 3) w = mtf.get_variable(mesh, 'w', shape=[hidden_dim], initializer=tf.constant_initializer( [0.1, -0.2, -0.1])) x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32) loss = mtf.reduce_mean(mtf.square(x - w)) var_grads = mtf.gradients( [loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf_optimize.AdamWeightDecayOptimizer( learning_rate=0.2) update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables) self.lowering = mtf.Lowering(graph, {mesh: mesh_impl}) tf_update_ops = [ self.lowering.lowered_operation(op) for op in update_ops ] return tf.group(tf_update_ops)
def _layer_norm(self, context, x, name=None): with tf.variable_scope(name, default_name="layer_norm"): scale = mtf.get_variable( context.mesh, "scale", mtf.Shape([context.model_dim]), initializer=tf.ones_initializer(), dtype=context.variable_dtype) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=context.model_dim) return x * mtf.rsqrt(variance + self._norm_epsilon) * scale
def clip_by_global_norm(grads, clip_norm): """Clip the grads by global norm.""" global_norm = mtf.sqrt( mtf.add_n( [mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None])) multiplier = clip_norm / mtf.maximum(global_norm, clip_norm) clipped_grads = [None if t is None else t * multiplier for t in grads] return clipped_grads, global_norm
def norm(x, axis=None, epsilon=1e-5): axis = default(axis, x.shape[-1]) u = mtf.reduce_mean(x, reduced_dim=axis) s = mtf.reduce_mean(mtf.square(x - u), reduced_dim=axis) u = mtf.broadcast(u, x.shape) s = mtf.broadcast(s, x.shape) return (x - u) * mtf.rsqrt(s + epsilon)
def apply_grad(self, grad, var): """See base class.""" if grad is None: tf.logging.warning("Gradient is None for variable %s" % var.name) return [] grad = mtf.to_float(grad) assignments = [] m = mtf.get_variable( var.mesh, var.name + "/adam_m", var.shape, initializer=tf.zeros_initializer(), # master_dtype=self.variable_dtype.master_dtype, # slice_dtype=self.variable_dtype.slice_dtype, # activation_dtype=self.variable_dtype.activation_dtype, trainable=False) v = mtf.get_variable( var.mesh, var.name + "/adam_v", var.shape, initializer=tf.zeros_initializer(), # master_dtype=self.variable_dtype.master_dtype, # slice_dtype=self.variable_dtype.slice_dtype, # activation_dtype=self.variable_dtype.activation_dtype, trainable=False) # Standard Adam update. next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad) update = next_m / (mtf.sqrt(next_v) + self.epsilon) # Just adding the square of the weights to the loss function is *not* # the correct way of using L2 regularization/weight decay with Adam, # since that will interact with the m and v parameters in strange ways. # # Instead we want to decay the weights in a manner that doesn't interact # with the m/v parameters. This is equivalent to adding the square # of the weights to the loss with plain (non-momentum) SGD. if self._do_use_weight_decay(var.name): update += mtf.to_float(var.value) * self.weight_decay_rate update_with_lr = self.learning_rate * update var_update = mtf.assign_sub(var, update_with_lr) assignments.extend( [var_update, mtf.assign(m, next_m), mtf.assign(v, next_v)]) return assignments
def toy_model(features, mesh): """A toy model implemented by mesh tensorlfow.""" batch_dim = mtf.Dimension('batch', FLAGS.batch_size) hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size) io_dim = mtf.Dimension('io', FLAGS.io_size) x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim])) h = mtf.layers.dense(x, hidden_dim, name='layer1', use_bias=False) y = mtf.layers.dense(h, io_dim, name='layer2', use_bias=False) loss = mtf.reduce_sum(mtf.square(y - x)) return y, loss
def model_backbone(features, labels, mesh): """The model. Args: image: tf.Tensor with shape [batch, 32*32] labels: a tf.Tensor with shape [batch] and dtype tf.int32 mesh: a mtf.Mesh Returns: logits: a mtf.Tensor with shape [batch, 10] loss: a mtf.Tensor with shape [] """ id_hldr, wt_hldr = features batch_dim = mtf.Dimension("batch", args_opt.batch_size) field_dim = mtf.Dimension("field", size=39) vocab_dim = mtf.Dimension("vocab_size", 200000) embed_dim = mtf.Dimension("embed_size", 80) outdim = mtf.Dimension("outdim", 1) id_hldr = mtf.import_tf_tensor( mesh, tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]), mtf.Shape([batch_dim, field_dim])) wt_hldr = mtf.import_tf_tensor( mesh, tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]), mtf.Shape([batch_dim, field_dim])) if args_opt.fp16: float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16) # id_hldr=mtf.cast(id_hldr,dtype=tf.int32) wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16) else: float16 = None logits, embedding_table = network[args_opt.model](id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=float16) logits = mtf.cast(logits, dtype=tf.float32) embedding_table = mtf.cast(embedding_table, dtype=tf.float32) if labels is None: wide_loss = None deep_loss = None else: labels = mtf.import_tf_tensor( mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim])) wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits( logits, labels) deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2 deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * deep_loss wide_loss = mtf.reduce_mean(wide_loss) return logits, wide_loss + deep_loss
def layer_norm( x, dim: mtf.Dimension, epsilon: float = 1e-6, subtract_mean=True, use_scale=True, use_bias=True, name=None, ): """Layer normalization over dimension dim. Args: x: a mtf.Tensor whose shape contains dim. dim: a mtf.Dimension epsilon: a floating point number subtract_mean: a boolean use_scale: a boolean use_bias: a boolean name: a string used for tf.variable_scope. Returns: a mtf.Tensor with same shape as x. """ with tf.variable_scope(name, default_name="layer_norm"): if subtract_mean: x -= mtf.reduce_mean(x, reduced_dim=dim) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=dim) x *= mtf.rsqrt(variance + epsilon) if use_scale: x *= mtf.get_variable( x.mesh, "scale", mtf.Shape([dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype, ) if use_bias: x += mtf.get_variable( x.mesh, "bias", mtf.Shape([dim]), initializer=tf.zeros_initializer(), activation_dtype=x.dtype, ) return x
def toy_model(features, mesh): """A toy model implemented by mesh tensorlfow.""" batch_dim = mtf.Dimension('batch', FLAGS.batch_size) io_dim = mtf.Dimension('io', FLAGS.io_size) master_dtype = tf.as_dtype(FLAGS.master_dtype) slice_dtype = tf.as_dtype(FLAGS.slice_dtype) activation_dtype = tf.as_dtype(FLAGS.activation_dtype) x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim])) x = mtf.cast(x, activation_dtype) h = x for lnum in range(1, FLAGS.num_hidden_layers + 2): if lnum + 1 == FLAGS.num_hidden_layers + 2: # output layer dim = io_dim elif lnum % 2 == 0: dim = mtf.Dimension('hidden_even', FLAGS.hidden_size) else: dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size) h = mtf.layers.dense(h, dim, use_bias=False, master_dtype=master_dtype, slice_dtype=slice_dtype, name='layer_%d' % lnum) y = h g = tf.train.get_global_step() if FLAGS.step_with_nan >= 0: # Trigger NaN in the forward pass, this is used for testing whether # MeshTensorFlow can handle occasional NaN value. y += mtf.import_tf_tensor( mesh, tf.divide( 0.0, tf.cond(tf.equal(g, FLAGS.step_with_nan), lambda: 0., lambda: 1.)), mtf.Shape([])) loss = mtf.reduce_mean(mtf.square(y - x)) return y, loss
def model_fn(features, labels, mode, params): # Get global step global_step = tf.train.get_global_step() # Construct mtf graph + mesh from params graph = mtf.Graph() mesh_shape = mtf.convert_to_shape(params["mesh_shape"]) layout_rules = mtf.convert_to_layout_rules(params["layout"]) # Mesh setup if params["use_tpu"]: var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules) else: var_placer = None gpu_ids = params["gpu_ids"] mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( mesh_shape, layout_rules, gpu_ids) # Trainable variable precision # Store to checkpoints in master type, train in slice type, compute in activation type if params["precision"] == "bfloat16": variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16) else: variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32) # Build mtf mesh object mesh = mtf.Mesh(graph, "my_mesh", var_placer) # Build mtf_features & seq length dict for getting number of microbatches # We need to pack inputs into a dict to pass into serialize_training_step features_dict = {"inputs": features, "labels": labels} sequence_length_dict = { "inputs": params["n_ctx"], "labels": params["n_ctx"] } params = add_mode_to_params(params, mode) batch_size = get_batch_size(params) batch_dim = mtf.Dimension("batch", batch_size) batch_dims = [batch_dim] feature_length = sequence_length_dict["inputs"] length_dim = mtf.Dimension("sequence", feature_length) mtf_features = {} for key, x in features_dict.items(): if x is not None: feature_shape = mtf.Shape(batch_dims + [length_dim]) if type(features_dict[key]) == dict: features_dict[key] = features_dict[key]["feature"] x = tf.cast(features_dict[key], tf.int32) x = tf.reshape(x, feature_shape.to_integer_list) mtf_features[key] = mtf.import_fully_replicated(mesh, x, feature_shape, name=key) # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model other_features = {} memory_length_dim = mtf.Dimension("memory_length", length_dim.size) attn_bias = biasmask_attn_weights( mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None # Add attn_bias into mtf_features other_features["attn_bias"] = attn_bias # Define other Dimensions that we'll need inside the model embd_dim = mtf.Dimension("embd", params["n_embd"]) vocab_dim = mtf.Dimension("vocab", params["n_vocab"]) # We need this because gathering when both the args have the same dimension in them breaks things # This dim is specifically for the weights # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." 
error embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"]) other_features["embd_dim"] = embd_dim other_features["vocab_dim"] = vocab_dim other_features["embed_sequence_dim"] = embed_sequence_dim other_features["memory_length_dim"] = memory_length_dim if mode == tf.estimator.ModeKeys.PREDICT: # Set up the model for prediction inputs = mtf_features["inputs"] if params["remove_partial_sequences"] is None: params["remove_partial_sequences"] = False export = params.get("export", False) if not export: mtf_samples = sample_autoregressive( inputs, other_features=other_features, params=params, variable_dtype=variable_dtype, remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=params['sampling_use_entmax']) else: with mtf.utils.outside_all_rewrites(): with tf.variable_scope('gpt2'): mtf_samples, loss, loss_batch = gpt2.model( mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None) mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) inputs = lowering.export_to_tf_tensor(inputs) outputs = lowering.export_to_tf_tensor(mtf_samples) predictions = {"inputs": inputs, "outputs": outputs} def scaffold_fn(): return tf.train.Scaffold( local_init_op=tf.group( tf.train.Scaffold.default_local_init_op(), lowering.copy_masters_to_slices(), name="mtf_local_init_op"), ready_op=tf.concat([ tf.report_uninitialized_variables(), resources.report_uninitialized_resources() ], axis=0, name="mtf_ready_op")) return tpu_estimator.TPUEstimatorSpec( mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions, scaffold_fn=scaffold_fn, prediction_hooks=[mtf.MtfRestoreHook(lowering)]) # We're not predicting, so we better be training or evaluating assert (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL) if mode == tf.estimator.ModeKeys.TRAIN: # Gets number of microbatches per batch for serialized training # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed num_microbatches = int( mtf_transformer.utils.serialize_num_microbatches( batch_dim=batch_dim, sequence_length=sequence_length_dict, mesh_shape=mesh_shape, layout_rules=layout_rules, tokens_per_microbatch_per_replica=params[ "tokens_per_mb_per_replica"])) else: num_microbatches = 1 params[ "num_microbatches"] = num_microbatches # Add num microbatches to params if num_microbatches > 1: # For serialize_training_step we need to modify the model to output results in a dict def serialized_fn(mtf_features): if params["model"] == "GPT": with tf.variable_scope('gpt2'): logits, loss, loss_batch = gpt2.model( mtf_features, other_features, params, mesh, variable_dtype=variable_dtype) return { "logits": logits, "loss": loss, "loss_batch": loss_batch } else: raise Exception( f"'{params['model']}' is not a valid model - please select from [GPT]" ) # Serialize the training step - Gradients are accumulated locally and reduced once. 
var_grads, output_dict = mtf.serialize_training_step( mtf_features, serialized_fn, batch_dim, num_microbatches) loss = output_dict["loss"] loss_batch = output_dict["loss_batch"] logits = output_dict["logits"] else: # If we're not splitting into microbatches, return logits & loss as is if params["model"] == "GPT": with mtf.utils.outside_all_rewrites(): with tf.variable_scope('gpt2'): logits, loss, loss_batch = gpt2.model( mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None) else: raise Exception( f"'{params['model']}' is not a valid model - please select from [GPT]" ) # Auto layout generation if params["auto_layout"]: auto_layout(graph, mesh_shape, logits, loss) if params["auto_layout_and_mesh_shape"]: auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss) if mode == tf.estimator.ModeKeys.TRAIN: # In TRAIN mode, get optimizer if params["num_microbatches"] > 1: # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn # So we pass them in here _, update_ops, var_grads = get_optimizer( mesh, loss, params, variable_dtype=variable_dtype, inp_var_grads=var_grads) else: # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank _, update_ops, var_grads = get_optimizer( mesh, loss, params, variable_dtype=variable_dtype) # Log summaries to tensorboard mtf.scalar_summary("loss", loss) # Log gradients if in params if params["log_grads"] not in [None, False]: for g in var_grads: grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g))) mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm) else: # For now, we can only export fully-replicated tensors. # This has to be done before lowering or they will not be included in the graph mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim) max_logits = mtf.argmax(logits, vocab_dim) del logits fully_replicated_mean_logits = mtf.anonymize(mean_logits) fully_replicated_max_logits = mtf.anonymize(max_logits) fully_replicated_loss_batch = mtf.anonymize(loss_batch) # Gets & prints info about no. trainable vars in the model & dimension names get_graph_info(graph) # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True) tf_loss = lowering.export_to_tf_tensor(loss) tf_loss = tf.cast(tf_loss, tf.float32) if mode == tf.estimator.ModeKeys.TRAIN: # Use our patched version until mtf updates theirs host_call = create_host_call(params['model_path']) mtf.utils.remove_summaries() # Creates train_op tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add( global_step, 1)) # Need to manually increment global_step tf.logging.info(f"tf_update_ops: {tf_update_ops}") train_op = tf.group(tf_update_ops) else: tf_mean_logits = lowering.export_to_tf_tensor( fully_replicated_mean_logits) tf_max_logits = lowering.export_to_tf_tensor( fully_replicated_max_logits) tf_loss_batch = tf.to_float( lowering.export_to_tf_tensor(fully_replicated_loss_batch)) with mtf.utils.outside_all_rewrites(): # Copy master variables to slices. Must be called first. 
restore_hook = mtf.MtfRestoreHook(lowering) if mode == tf.estimator.ModeKeys.TRAIN: # Set up the checkpoint server and return the TPUEstimatorSpec saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10, keep_checkpoint_every_n_hours=2, defer_build=False, save_relative_paths=True) tf.add_to_collection(tf.GraphKeys.SAVERS, saver) saver_listener = mtf.MtfCheckpointSaverListener(lowering) saver_hook = tf.train.CheckpointSaverHook( params["model_path"], save_steps=params["steps_per_checkpoint"], saver=saver, listeners=[saver_listener]) return tpu_estimator.TPUEstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=tf_loss, host_call=host_call, train_op=train_op, training_hooks=[restore_hook, saver_hook]) elif mode == tf.estimator.ModeKeys.EVAL: # Evaluation metrics def _perplexity(loss): perplexity = tf.exp(loss) return tf.metrics.mean(perplexity) def _bits_per_byte(loss): bpb = loss * (0.29335 / math.log(2)) return tf.metrics.mean(bpb) def _metric_fn(tf_mean_logits, tf_loss_batch): mean_logits = tf.metrics.mean(tf_mean_logits) loss = tf.reduce_mean(tf_loss_batch) perp = _perplexity(loss) bpb = _bits_per_byte(loss) return { "mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb } def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch): eos_token = params["eos_id"] answer_positions = tf.where( tf.math.not_equal(labels, eos_token)) correct_answers = tf.gather_nd( tf.math.equal(tf_max_logits, labels), answer_positions) accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32)) # I guess tf_loss_batch has z_loss and maybe other stuff added to it # so maybe this should be calculated separately in the future answer_loss = tf.gather_nd(tf_loss_batch, answer_positions) log_perplexity = tf.metrics.mean(answer_loss) return { "lambada_acc": accuracy, "lambada_log_ppl": log_perplexity } eval_task = params["eval_task"] if eval_task == "lambada": eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch]) else: eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch]) return tpu_estimator.TPUEstimatorSpec( tf.estimator.ModeKeys.EVAL, evaluation_hooks=[restore_hook], loss=tf_loss, eval_metrics=eval_metrics)
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print("Dtype : ", dtype, npdtype) # Compute a few things first, using simple tensorflow kny = 1 * np.pi * nc / bs R1, R2 = 3., 3 * 1.2 stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition scalar = mtf.Dimension("scalar", 1) fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) #k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # # Begin simulation ## Compute initial initial conditions distributed #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) fieldvar = mtf.get_variable(mesh, 'linear', part_shape) input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape) linearop = mtf.assign(fieldvar, mtfinp) #field = fieldvar initc = fieldvar print("initc : ", initc) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( 
fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( initc, stages[-1], kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) ## x = final_field ppars, mpars, kernel = setupfnn() pwts, pbias, pmx, psx = ppars mwts, mbias, mmx, msx, mmy, msy = mpars msy, mmy = msy[0], mmy[0] print("mmy : ", mmy) size = 3 k_dims = [d.shape[0] for d in kv] k_dims = [k_dims[2], k_dims[0], k_dims[1]] tfnc, tfbs = float_to_mtf(nc * 1., mesh, scalar), float_to_mtf(bs, mesh, scalar) x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs], output_dtype=cdtype) x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype) x1d = mtf.add(x1d, -1.) x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype) x2f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype) x12 = x1 - x2 width = tf.placeholder(tf.float32, shape=()) def apply_pwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID') yy = tf.concat([y, y1, y2], axis=-1) yy = yy - pmx yy = yy / psx yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0]) yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1]) yy3 = tf.matmul(yy2, pwts[2]) + pbias[2] pmodel = tf.nn.sigmoid(width * yy3) return pmodel[..., 0] pmodel = mtf.slicewise( apply_pwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_pwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) def apply_mwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) zz = tf.concat([ tf.expand_dims(x, -1), tf.expand_dims(x1, -1), tf.expand_dims(x2, -1) ], axis=-1) zz = zz - mmx zz = zz / msx zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0]) zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1]) zz3 = tf.matmul(zz2, mwts[2]) + mbias[2] mmodel = zz3 * msy + mmy return mmodel[..., 0] mmodel = mtf.slicewise( 
apply_mwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_mwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) model = pmodel * mmodel mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior #k_dims = [d.shape[0] for d in kv] #k_dims = [k_dims[2], k_dims[0], k_dims[1]] k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3 # Total loss #diff = (model - mtfdata) modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype) modelsmf = mtf.cwise(cwise_fingauss, [modelf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype) #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype) #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype) ##Anneal R0 = tf.placeholder(tf.float32, shape=()) M0 = tf.placeholder(tf.float32, shape=()) off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder( tf.float32, shape=data.shape) mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape) mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape) diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0) #diff = diff / 0.25 #diff = (diff + mtfoff)*mtfistd #For some reason, doing things wrong this one diff = (diff + mtfoff) / 0.25 def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype) var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print(dtype, npdtype) # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition downsampling_factor = 2 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype(npdtype), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype(npdtype), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype(npdtype), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False, dtype=npdtype) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False, dtype=npdtype) kx_hr = mtf.import_tf_tensor(mesh, 
kvec_hr[0].squeeze().astype(npdtype), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor(mesh, kvec_hr[1].squeeze().astype(npdtype), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor(mesh, kvec_hr[2].squeeze().astype(npdtype), shape=[padded_sz_dim]) kv_hr = [ky_hr, kz_hr, kx_hr] # kvec for prior blocks prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x) prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y) prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z) kvec_pr = flowpm.kernels.fftk( [nc // n_block_x, nc // n_block_y, nc // n_block_z], symmetric=False, dtype=npdtype) kx_pr = mtf.import_tf_tensor(mesh, kvec_pr[0].squeeze().astype(npdtype), shape=[prior_sx_dim]) ky_pr = mtf.import_tf_tensor(mesh, kvec_pr[1].squeeze().astype(npdtype), shape=[prior_sy_dim]) kz_pr = mtf.import_tf_tensor(mesh, kvec_pr[2].squeeze().astype(npdtype), shape=[prior_sz_dim]) kv_pr = [ky_pr, kz_pr, kx_pr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] ## Compute initial initial conditions distributed fieldvar = mtf.get_variable(mesh, 'linear', hr_shape) input_field = tf.placeholder(data.dtype, [ batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x, nc // n_block_y, nc // n_block_z ]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape) linearop = mtf.assign(fieldvar, mtfinp) # field = fieldvar initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) # for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=initc.dtype, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) final_state = mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) else: final_state = mtfpm.lpt_init(low, high, stages[-1], kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for 
block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv_pr] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 # Total loss diff = (final_field - mtfdata) R0 = tf.placeholder(tf.float32, shape=()) def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype) var_grads = [ mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype) ] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
def normalize(x): scale = layer_norm_vars.pop(0) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale
def norm(x, axis, epsilon=1e-8): x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u") s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s") return x * mtf.rsqrt(s + epsilon)
def attention(q, k, v, memory_length_dim, key_dim, value_dim, bias=None, dropout_rate=0.0, dropout_broadcast_dims=None, extra_logit=None, context=None, float32_logits=True, z_loss_coeff=None): """Dot-product attention - doesn't use positional dimensions. key_dim is a Dimension representing the channels in the queries and keys value_dim is a Dimension representing the channels in values memory_length_dim is a Dimension representing the different key/value pairs. Dimensions of q: other_query_dims + {key_dim} Dimensions of k: other_memory_dims + {memory_length_dim, key_dim} Dimensions of v: other_memory_dims + {memory_length_dim, value_dim} other_memory_dims is a subset of other_query_dims Typically, other_query_dims={batch, heads, length} Typically, other_memory_dims={batch, heads} Args: q: a Tensor k: a Tensor v: a Tensor memory_length_dim: a Dimension key_dim: a Dimension value_dim: a Dimension bias: a Tensor to be added into the attention logits. dropout_rate: a float. dropout_broadcast_dims: an optional list of mtf.Dimension extra_logit: an optional scalar or tensor context: an optional Transformer.Context float32_logits: a boolean - if True, then compute logits in float32 to avoid numerical issues with bfloat16 z_loss_coeff: a float, if z_loss_coeff is not None then add an auxiliary loss to push the attention logits closer to zero. This helps to stabilize model training. Returns: Tensor with shape q.shape - key_dim + value_dim """ orig_q_shape = q.shape q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding( context, q, k, v, bias, [key_dim, value_dim]) if float32_logits: k = mtf.cast(k, tf.float32) q = mtf.cast(q, tf.float32) logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim]) if bias is not None: logits += mtf.cast(bias, logits.dtype) # Adds auxiliary z-loss to push the attention logits towards zero. if z_loss_coeff is not None and context.train: tf.logging.info("attention z_loss being added: {}".format( tf.get_variable_scope().name)) log_z = mtf.reduce_logsumexp(logits, memory_length_dim) z_loss = mtf.square(log_z) * mtf.cast(context.nonpadding, log_z.dtype) z_loss = mtf.reduce_mean(z_loss) if context.num_microbatches and context.num_microbatches > 1: tf.logging.info( "Dividing attention z-loss loss by num_microbatches={}".format( context.num_microbatches)) z_loss /= context.num_microbatches if context.train: mtf.scalar_summary("attention_z_loss", z_loss) z_loss *= z_loss_coeff context.losses.append(mtf.cast(z_loss, v.dtype)) weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit) weights = mtf.cast(weights, v.dtype) weights = mtf.dropout( weights, context.train, 1.0 - dropout_rate, noise_shape=weights.shape - dropout_broadcast_dims) outputs_shape = q.shape - key_dim + value_dim outputs = mtf.einsum([weights, v], outputs_shape) outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim) return outputs
def recon_model(mesh, datasm, rsdfactor, M0, R0, width, off, istd, x0, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print("Dtype : ", dtype, npdtype) # Compute a few things first, using simple tensorflow kny = 1 * np.pi * nc / bs R1, R2 = 3., 3 * 1.2 stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition scalar = mtf.Dimension("scalar", 1) fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) #k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3] # # Begin simulation if x0 is None: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.random_normal_initializer( mean=0.0, stddev=1, seed=None)) else: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.constant_initializer(x0)) ## state = mtfpm.lpt_init_single( fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) 
final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) final_field = mtf.zeros(mesh, shape=part_shape) final_field = mcomp.cic_paint_fr(final_field, final_state, output_shape=part_shape, hr_shape=hr_shape, halo_size=halo_size, splittables=splittables, mesh=mesh) ## x = final_field ppars, mpars, kernel = setupfnn() pwts, pbias, pmx, psx = ppars mwts, mbias, mmx, msx, mmy, msy = mpars msy, mmy = msy[0], mmy[0] size = 3 k_dims = [d.shape[0] for d in kv] k_dims = [k_dims[2], k_dims[0], k_dims[1]] tfnc, tfbs = float_to_mtf(nc * 1., mesh, scalar), float_to_mtf(bs, mesh, scalar) x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs], output_dtype=cdtype) x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype) x1d = mtf.add(x1d, -1.) x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype) x2f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype) x12 = x1 - x2 def apply_pwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID') yy = tf.concat([y, y1, y2], axis=-1) yy = yy - pmx yy = yy / psx yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0]) yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1]) yy3 = tf.matmul(yy2, pwts[2]) + pbias[2] pmodel = tf.nn.sigmoid(tf.constant(width) * yy3) return pmodel[..., 0] pmodel = mtf.slicewise( apply_pwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_pwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) def apply_mwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) zz = tf.concat([ tf.expand_dims(x, -1), tf.expand_dims(x1, -1), tf.expand_dims(x2, -1) ], axis=-1) zz = zz - mmx zz = zz / msx zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0]) zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1]) zz3 = tf.matmul(zz2, mwts[2]) + mbias[2] mmodel = zz3 * msy + mmy return mmodel[..., 0] mmodel = mtf.slicewise( apply_mwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_mwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) model = pmodel * mmodel ##RSD below hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables, mesh) mstate = mpm.mtf_indices(hr_field.mesh, shape=part_shape[1:], dtype=tf.float32) X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate], output_shape=[batch_dim] + mstate.shape[:]) massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype) masssmf = mtf.cwise(cwise_fingauss, [massf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype) masssm = masssm + 1e-5 imasssm = mtf.pow(x, -1.) 
vzweights = final_state[1] vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights], output_dtype=tf.float32, output_shape=vzweights.shape[:-1], name='get_vz', splittable_dims=vzweights.shape[1:-1]) print("weights : ", vzweights) momz = mtf.zeros(mesh, shape=part_shape) momz = mcomp.cic_paint_fr(final_field, final_state, output_shape=part_shape, hr_shape=hr_shape, \ halo_size=halo_size, splittables=splittables, mesh=mesh, weights=vzweights) momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype) momzsmf = mtf.cwise(cwise_fingauss, [momzf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype) #Shift velzsm = mtf.divide(momzsm, masssm) vz = mcomp.cic_readout_fr(velzsm, [X], hr_shape=hr_shape, halo_size=halo_size, splittables=splittables, mesh=mesh) vz = mtf.multiply(vz, rsdfactor) print("vz : ", vz) Xrsd = mtf.slicewise(lambda x, vz: x + tf.stack( [tf.zeros_like(vz), tf.zeros_like(vz), vz], 4), [X, vzweights], output_dtype=tf.float32, output_shape=X.shape, name='add_vz', splittable_dims=X.shape[1:-1]) print(Xrsd) modelread = mcomp.cic_readout_fr(model, [X], hr_shape=hr_shape, halo_size=halo_size, splittables=splittables, mesh=mesh) modelrsd = mtf.zeros(mesh, shape=part_shape) modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd], output_shape=part_shape, hr_shape=hr_shape, \ halo_size=halo_size, splittables=splittables, mesh=mesh, weights=modelread) model = modelrsd print(modelrsd) #Likelihood and prior here mtfdatasm = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(datasm), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3 # Total loss #diff = (model - mtfdata) modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype) modelsmf = mtf.cwise(cwise_fingauss, [modelf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype) ##Anneal M0 = tf.constant(M0) diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0) if off is not None: mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape) diff = diff + mtfoff if istd is not None: mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape) diff = (diff + mtfoff ) * mtfistd #For some reason, doing things wrong this one else: diff = diff / 0.25 def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior fields = [fieldvar, 
final_field, model] metrics = [chisq, prior, loss] return fields, metrics, kv
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
  """Prototype of function computing LPT displacement.

  Returns output tensorflow and mesh tensorflow tensors.
  """
  if dtype == tf.float32:
    npdtype = "float32"
    cdtype = tf.complex64
  elif dtype == tf.float64:
    npdtype = "float64"
    cdtype = tf.complex128
  print(dtype, npdtype)

  # Compute a few things first, using simple tensorflow
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z):
    new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
    print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
    halo_size = new_size

  # Parameters of the large scales decomposition
  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  k_dims = [tx_dim, ty_dim, tz_dim]
  batch_dim = mtf.Dimension("batch", batch_size)

  klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0]
  plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
  ipklin = iuspline(klin, plin)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  kv = [ky, kz, kx]

  # kvec for low resolution grid
  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  # Begin simulation
  if x0 is None:
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=tf.random_normal_initializer(
                                    mean=0.0, stddev=1, seed=None))
  else:
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=tf.constant_initializer(x0))
  print("\nfieldvar : \n", fieldvar)

  # Here we can run our nbody
  if FLAGS.nbody:
    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # Here we can run our nbody
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)
  else:
    final_state = mtfpm.lpt_init_single(
        fieldvar,
        stages[-1],
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)

  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)

  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                              output_dtype=dtype,
                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                              name='my_dumb_reshape',
                              splittable_dims=part_shape[:-1] + hr_shape[:4])

  mtfdata = mtf.import_tf_tensor(mesh,
                                 tf.convert_to_tensor(data),
                                 shape=shape)

  # Get prior
  k_dims_pr = [d.shape[0] for d in kv]
  k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
  cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

  def _cwise_prior(kfield, pk, kx, ky, kz):
    kx = tf.reshape(kx, [-1, 1, 1])
    ky = tf.reshape(ky, [1, -1, 1])
    kz = tf.reshape(kz, [1, 1, -1])
    kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
    kshape = kk.shape
    kk = tf.reshape(kk, [-1])
    pkmesh = tfp.math.interp_regular_1d_grid(
        x=kk,
        x_ref_min=1e-05,
        x_ref_max=1000.0,
        y_ref=pk,
        grid_regularizing_transform=tf.log)
    priormesh = tf.reshape(pkmesh, kshape)
    return tf.abs(kfield) / priormesh**0.5

  cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32)
  prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

  # Total loss
  diff = (final_field - mtfdata)
  R0 = tf.constant(R0)
  print("R0 in the recon_model : ", R0)

  def _cwise_smooth(kfield, kx, ky, kz):
    kx = tf.reshape(kx, [-1, 1, 1])
    ky = tf.reshape(ky, [1, -1, 1])
    kz = tf.reshape(kz, [1, 1, -1])
    kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
    wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
    return kfield * wts

  # Element-wise function that applies a Fourier kernel
  plambda = FLAGS.plambda

  def _cwise_logprob(finalfield, data):
    galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
    logprob = galmean.log_prob(data)
    return -1 * logprob

  cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
  cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
  final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)

  chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                    output_dtype=tf.float32)
  chisq = mtf.reduce_sum(chisq)
  loss = chisq + prior

  def _cwise_sample(finalfield, data):
    galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
    sample = galmean.sample()
    return sample

  sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                     output_dtype=tf.float32)

  fields = [fieldvar, sample]
  metrics = [chisq, prior, loss]

  return fields, metrics, kv
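# A standalone NumPy sketch (not part of the mtf graph above) of the prior
# term assembled by _cwise_prior: interpolate the linear power spectrum P(k)
# onto the 3-D |k| grid, whiten the Fourier modes of the linear field as
# |delta_k| / sqrt(P(k)), and sum the squares times the box volume.
# `klin`, `plin`, `nc` and `bs` are assumed to match recon_model; FFT
# normalization and the log-regularized TFP interpolation are simplified away.
import numpy as np


def prior_sketch(linear_field, klin, plin, nc, bs):
  kvec = np.fft.fftfreq(nc) * 2 * np.pi  # fftk-style grid frequencies (assumption)
  kx, ky, kz = np.meshgrid(kvec, kvec, kvec, indexing='ij')
  kk = np.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)

  pk_mesh = np.interp(kk, klin, plin)  # P(k) on the grid (linear interpolation)
  delta_k = np.fft.fftn(linear_field)
  whitened = np.abs(delta_k) / np.sqrt(np.maximum(pk_mesh, 1e-12))
  return np.sum(whitened**2) * bs**3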
  else:
    float16 = None

  result, embedding_table = widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim,
                                     outdim, float16)
  # label = mtf.reshape(label,new_shape=[batch_dim, outdim])
  # output = mtf.layers.softmax_cross_entropy_with_logits(result, label,vocab_dim=outdim)
  # result = mtf.sigmoid(result)
  # result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
  # result = mtf.reduce_sum(result)
  result = mtf.cast(result, dtype=tf.float32)
  embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
  probability = mtf.sigmoid(result)
  result = mtf.layers.sigmoid_cross_entropy_with_logits(result, label)
  wide_loss = mtf.reduce_mean(result)
  deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
  deep_loss = mtf.reduce_mean(result) + 8e-5 * deep_loss
  # print("========", global_step)

  devices = ["gpu:0"]
  mesh_shape = [("all_processors", 1)]
  layout_rules = [("dim1", "all_processors")]
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, devices)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  wide_loss = lowering.export_to_tf_tensor(wide_loss)
  # result = lowering.export_to_tf_tensor(result)
  # predict = lowering.export_to_tf_tensor(probability)
  # predict = tf.where(predict>0.5,tf.ones_like(predict),tf.zeros_like(predict))
  deep_loss = lowering.export_to_tf_tensor(deep_loss)
  print(wide_loss)
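# The placement-mesh lowering pattern used above, exercised end to end on a
# toy graph. A minimal single-device sketch: the dimension name "dim1", the
# zero-initialized variable and the constant values are made up for this
# example; only the mtf/TF calls mirror the snippet above.
import mesh_tensorflow as mtf
import tensorflow as tf  # TF1.x-style API, as assumed throughout this file

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

toy_dim = mtf.Dimension("dim1", 4)
w = mtf.get_variable(mesh, "w", shape=[toy_dim],
                     initializer=tf.zeros_initializer())
x = mtf.constant(mesh, [0.5, -0.5, 1.0, 0.0], [toy_dim], dtype=tf.float32)
loss = mtf.reduce_mean(mtf.square(x - w))

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    [("all_processors", 1)],       # mesh shape: a single processor
    [("dim1", "all_processors")],  # layout: split "dim1" across it
    [""])                          # default device
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(lowering.copy_masters_to_slices())  # sync master variables to slices
  print(sess.run(tf_loss))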
def model(self, mesh, x, y, params):
  # x :: [batch, io, vocab]
  if params["precision"] == "bfloat16":
    dtype = tf.bfloat16
    # master has type float32, slice and activation have type bfloat16
    variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16, tf.bfloat16)
  else:
    dtype = tf.float32
    # master, slice and activation all have type float32
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

  # Build the actual model
  batch_dim = mtf.Dimension("batch", params["batch_size"])
  vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
  io_dim = mtf.Dimension("sequence", params["io"])
  io_chan_dim = mtf.Dimension("io", params["io_channels"])

  # from input to mtf
  x = mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, io_dim, vocab_dim]))

  # Embeddings
  with tf.variable_scope("toy", default_name="seq2seq"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      embedding_table = mtf.get_variable(
          mesh,
          "word_embeddings",
          mtf.Shape([vocab_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      word_embedding_output = mtf.gather(embedding_table,
                                         x,
                                         dim=vocab_dim,
                                         output_shape=io_chan_dim)

      # Add positional embeddings and token type embeddings, then layer
      # normalize and perform dropout.
      embedding_output = word_embedding_output

      pos_embedding = mtf.get_variable(
          mesh,
          "pos_embeddings",
          mtf.Shape([io_dim, io_chan_dim]),
          initializer=self.embedding_initializer,
      )
      embedding_output = self.normalize(embedding_output)
      embedding_output = mtf.dropout(
          embedding_output,
          keep_prob=1.0 - self.config.layer_output_dropout_prob,
      )

  # shift token by pos embeddings
  x = word_embedding_output + pos_embedding
  x = mtf.cast(x, variable_dtype.activation_dtype)

  h = x
  for lnum in range(1, self.num_hidden_layers + 2):
    if lnum + 1 == self.num_hidden_layers + 2:
      # output layer
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension("hidden_even", io_chan_dim.size)
    else:
      dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
    h = mtf.layers.dense(
        h,
        dim,
        use_bias=False,
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        name="layer_%d" % lnum,
    )
  prediction = h  # project back to token dimensions

  # compute the mean square loss between the input and the output
  loss = mtf.reduce_mean(mtf.square(y - prediction))
  return prediction, loss
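# Why the layer loop above alternates the names "hidden_even" / "hidden_odd":
# an mtf.Shape may not contain two dimensions with the same name, so a dense
# layer mapping "hidden" -> "hidden" would need an illegal [hidden, hidden]
# kernel. A minimal sketch (dimension names and sizes are made up here):
import mesh_tensorflow as mtf
import tensorflow as tf  # TF1.x-style API, as assumed throughout this file

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

batch = mtf.Dimension("batch", 2)
h_odd = mtf.Dimension("hidden_odd", 8)
h_even = mtf.Dimension("hidden_even", 8)

x = mtf.zeros(mesh, mtf.Shape([batch, h_odd]))
# kernel shape [hidden_odd, hidden_even] -- valid
h1 = mtf.layers.dense(x, h_even, use_bias=False, name="layer_1")
# kernel shape [hidden_even, hidden_odd] -- valid
h2 = mtf.layers.dense(h1, h_odd, use_bias=False, name="layer_2")
# mtf.layers.dense(h1, h_even) would require a [hidden_even, hidden_even]
# kernel and fail, since dimension names within one shape must be unique.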
def recon_model(mesh,
                data,
                bparams,
                ipkerror,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
  """Prototype of function computing LPT displacement.

  Returns output tensorflow and mesh tensorflow tensors.
  """
  b1, b2, bs2 = bparams
  kerror, perror = ipkerror[0].astype(np.float32), ipkerror[1].astype(
      np.float32)

  if dtype == tf.float32:
    npdtype = "float32"
    cdtype = tf.complex64
  elif dtype == tf.float64:
    npdtype = "float64"
    cdtype = tf.complex128
  print("Dtype : ", dtype, npdtype)

  # Compute a few things first, using simple tensorflow
  kny = 1 * np.pi * nc / bs
  R1, R2 = 3., 3 * 1.2
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z):
    new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
    print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
    halo_size = new_size

  # Parameters of the large scales decomposition
  scalar = mtf.Dimension("scalar", 1)

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  #k_dims = [tx_dim, ty_dim, tz_dim]
  batch_dim = mtf.Dimension("batch", batch_size)

  klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0]
  plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
  ipklin = iuspline(klin, plin)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

  pke_dim = mtf.Dimension("epk", len(perror))
  pkerror = mtf.import_tf_tensor(mesh, perror.astype(npdtype), shape=[pke_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  kv = [ky, kz, kx]

  # kvec for low resolution grid
  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

  # Begin simulation
  if x0 is None:
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=tf.random_normal_initializer(
                                    mean=0.0, stddev=1, seed=None))
  else:
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=tf.constant_initializer(x0))

  state = mtfpm.lpt_init_single(
      fieldvar,
      a0,
      kv_lr,
      halo_size,
      lr_shape,
      hr_shape,
      part_shape[1:],
      antialias=True,
  )
  final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                   halo_size)

  # paint the field
  final_field = mtf.zeros(mesh, shape=part_shape)
  final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                   hr_shape, halo_size, splittables, mesh)

  # Get the fields for bias
  hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables, mesh)
  mstate = mpm.mtf_indices(hr_field.mesh, shape=part_shape[1:],
                           dtype=tf.float32)
  X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                 output_shape=[batch_dim] + mstate.shape[:])

  k_dims_pr = [d.shape[0] for d in kv]
  k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
  tfnc = cswisef.float_to_mtf(nc * 1., mesh, scalar)
  tfbs = cswisef.float_to_mtf(bs, mesh, scalar)

  initc = fieldvar
  d0 = initc - mtf.reduce_mean(initc)

  d2 = initc * initc
  d2 = d2 - mtf.reduce_mean(d2)

  cfield = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
  shearfield = mtf.zeros(mesh, shape=part_shape)
  shearfield = shear(shearfield, cfield, kv, tfnc, tfbs)
  s2 = shearfield - mtf.reduce_mean(shearfield)

  dread = mcomp.cic_readout_fr(d0, [X],
                               hr_shape=hr_shape,
                               halo_size=halo_size,
                               splittables=splittables,
                               mesh=mesh)
  d2read = mcomp.cic_readout_fr(d2, [X],
                                hr_shape=hr_shape,
                                halo_size=halo_size,
                                splittables=splittables,
                                mesh=mesh)
  s2read = mcomp.cic_readout_fr(s2, [X],
                                hr_shape=hr_shape,
                                halo_size=halo_size,
                                splittables=splittables,
                                mesh=mesh)

  ed, ed2, es2 = mtf.zeros(mesh, shape=part_shape), mtf.zeros(
      mesh, shape=part_shape), mtf.zeros(mesh, shape=part_shape)
  ed = mcomp.cic_paint_fr(ed,
                          final_state,
                          output_shape=part_shape,
                          hr_shape=hr_shape,
                          halo_size=halo_size,
                          splittables=splittables,
                          mesh=mesh,
                          weights=dread)
  ed2 = mcomp.cic_paint_fr(ed2,
                           final_state,
                           output_shape=part_shape,
                           hr_shape=hr_shape,
                           halo_size=halo_size,
                           splittables=splittables,
                           mesh=mesh,
                           weights=d2read)
  es2 = mcomp.cic_paint_fr(es2,
                           final_state,
                           output_shape=part_shape,
                           hr_shape=hr_shape,
                           halo_size=halo_size,
                           splittables=splittables,
                           mesh=mesh,
                           weights=s2read)

  model = ed * b1 + ed2 * b2 + es2 * bs2

  mtfdata = mtf.import_tf_tensor(mesh,
                                 tf.convert_to_tensor(data),
                                 shape=shape)
  diff = model - mtfdata

  # Get prior
  k_dims_pr = [d.shape[0] for d in kv]
  k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
  cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

  def _cwise_prior(kfield, pk, kx, ky, kz):
    kx = tf.reshape(kx, [-1, 1, 1])
    ky = tf.reshape(ky, [1, -1, 1])
    kz = tf.reshape(kz, [1, 1, -1])
    kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
    kshape = kk.shape
    kk = tf.reshape(kk, [-1])
    pkmesh = tfp.math.interp_regular_1d_grid(
        x=kk,
        x_ref_min=1e-05,
        x_ref_max=1000.0,
        y_ref=pk,
        grid_regularizing_transform=tf.log)
    priormesh = tf.reshape(pkmesh, kshape)
    return tf.abs(kfield) / priormesh**0.5

  cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32)
  prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #* nc**3

  # Total loss
  cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

  def _cwise_diff(kfield, pk, kx, ky, kz):
    kx = tf.reshape(kx, [-1, 1, 1])
    ky = tf.reshape(ky, [1, -1, 1])
    kz = tf.reshape(kz, [1, 1, -1])
    kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
    kshape = kk.shape
    kk = tf.reshape(kk, [-1])
    pkmesh = tfp.math.interp_regular_1d_grid(x=kk,
                                             x_ref_min=kerror.min(),
                                             x_ref_max=kerror.max(),
                                             y_ref=pk)
    priormesh = tf.reshape(pkmesh, kshape)
    priormesh = tf.cast(priormesh**0.5, kfield.dtype)
    return kfield / priormesh

  cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv, output_dtype=cdtype)

  def _cwise_smooth(kfield, kx, ky, kz):
    kx = tf.reshape(kx, [-1, 1, 1])
    ky = tf.reshape(ky, [1, -1, 1])
    kz = tf.reshape(kz, [1, 1, -1])
    kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
    wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
    return kfield * wts

  cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
  diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
  chisq = mtf.reduce_sum(mtf.square(diff))
  loss = chisq + prior

  fields = [fieldvar, final_field, model]
  metrics = [chisq, prior, loss]

  return fields, metrics, kv
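# A standalone NumPy sketch (not part of the mtf graph above) of the
# error-weighted, smoothed chi-square built from _cwise_diff and _cwise_smooth:
# whiten the Fourier residual by sqrt(P_err(k)), damp modes beyond ~1/R0 with
# the Gaussian kernel, and sum the squares in real space. `kerror`, `perror`,
# `nc`, `bs` and `R0` are assumed to match recon_model; FFT normalization
# conventions and the regular-grid TFP interpolation are simplified away.
import numpy as np


def chisq_sketch(model_field, data, kerror, perror, nc, bs, R0):
  kvec = np.fft.fftfreq(nc) * 2 * np.pi  # fftk-style grid frequencies (assumption)
  kx, ky, kz = np.meshgrid(kvec, kvec, kvec, indexing='ij')
  kk = np.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)

  residual_k = np.fft.fftn(model_field - data)
  perr_mesh = np.interp(kk, kerror, perror)      # error power spectrum on the grid
  residual_k /= np.sqrt(np.maximum(perr_mesh, 1e-12))
  residual_k *= np.exp(-(kk * R0 * bs / nc)**2)  # same Gaussian cutoff as _cwise_smooth
  diff_real = np.fft.ifftn(residual_k).real
  return np.sum(diff_real**2)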