Ejemplo n.º 1
0
            def computation_fn():
                """Build, lower and return a one-variable mtf training step.

                A 3-element variable ``w`` is fit to a constant target with a
                mean-squared-error loss. The optimizer comes from
                ``optimization_lib.create_optimizer``. The lowering is stored on
                ``self.lowering``.

                Returns:
                  (lr, train_op): the learning rate produced by the optimizer
                  factory, and a tf op grouping the lowered update ops with a
                  global-step increment.
                """
                mtf_graph = mtf.Graph()
                my_mesh = mtf.Mesh(mtf_graph, 'my_mesh')
                shape = mtf.convert_to_shape('all:2')
                rules = mtf.convert_to_layout_rules('none:all')
                devices = [''] * shape.size
                impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    shape, rules, devices, device_assignment)

                hidden = mtf.Dimension('hidden', 3)
                weights = mtf.get_variable(
                    my_mesh, 'w', shape=[hidden],
                    initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
                target = mtf.constant(
                    my_mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
                diff = target - weights
                loss = mtf.reduce_mean(mtf.square(diff))

                learning_rate, mtf_update_ops = (
                    optimization_lib.create_optimizer(loss, 0.2, 100, 10))
                self.lowering = mtf.Lowering(mtf_graph, {my_mesh: impl})

                lowered = list(
                    map(self.lowering.lowered_operation, mtf_update_ops))
                # The global step is not advanced by the mtf ops, so bump it here.
                step_increment = tf.assign_add(
                    tf.train.get_or_create_global_step(), 1)
                return learning_rate, tf.group(lowered + [step_increment])
Ejemplo n.º 2
0
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow.

  Stacks ``FLAGS.num_hidden_layers`` bias-free dense layers, alternating
  between the ``hidden_odd`` and ``hidden_even`` dimensions. A final dense
  layer projects back to ``io_dim``. The loss is the mean squared difference
  between that output and the (cast) input.

  Args:
    features: a tf.Tensor imported with mtf shape [batch, io]
      (``FLAGS.batch_size`` x ``FLAGS.io_size``).
    mesh: the mtf.Mesh to build the model on.

  Returns:
    (y, loss): the output mtf.Tensor and the scalar mean-squared-error loss.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  x = mtf.cast(x, activation_dtype)
  h = x
  num_layers = FLAGS.num_hidden_layers + 1  # hidden layers + output layer
  # `range` (not the Python-2-only `xrange`) so this runs on Python 2 and 3.
  for lnum in range(1, num_layers + 1):
    if lnum == num_layers:
      # output layer
      dim = io_dim
    elif lnum % 2 == 0:
      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
    else:
      # Alternating dimension names keep consecutive layers' input and output
      # dims distinct.
      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
    h = mtf.layers.dense(
        h, dim,
        use_bias=False,
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        name='layer_%d' % lnum)
  y = h

  loss = mtf.reduce_mean(mtf.square(y - x))
  return y, loss
Ejemplo n.º 3
0
 def computation_fn():
     """Build and lower one AdamWeightDecay step on a one-variable mtf graph.

     A 3-element variable ``w`` is fit to a constant target with a
     mean-squared-error loss. Gradients come from ``mtf.gradients`` and are
     applied by ``mtf_optimize.AdamWeightDecayOptimizer``. The lowering is
     stored on ``self.lowering``.

     Returns:
       A tf op grouping the lowered variable-update ops.
     """
     g = mtf.Graph()
     my_mesh = mtf.Mesh(g, 'my_mesh')
     shape = mtf.convert_to_shape('all:2')
     rules = mtf.convert_to_layout_rules('none:all')
     devices = [''] * shape.size
     impl = mtf.simd_mesh_impl.SimdMeshImpl(
         shape, rules, devices, device_assignment)

     hidden = mtf.Dimension('hidden', 3)
     weights = mtf.get_variable(
         my_mesh, 'w', shape=[hidden],
         initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
     target = mtf.constant(
         my_mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
     diff = target - weights
     loss = mtf.reduce_mean(mtf.square(diff))

     trainables = g.trainable_variables
     grads = mtf.gradients([loss], [var.outputs[0] for var in trainables])
     adam = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
     mtf_updates = adam.apply_grads(grads, trainables)

     self.lowering = mtf.Lowering(g, {my_mesh: impl})
     return tf.group(list(map(self.lowering.lowered_operation, mtf_updates)))
Ejemplo n.º 4
0
 def _layer_norm(self, context, x, name=None):
   """Normalize ``x`` by its root-mean-square over ``context.model_dim``.

   No mean is subtracted and no bias is added. The result is multiplied by a
   learned ``scale`` variable (initialized to ones) created under a
   ``layer_norm`` variable scope, or under ``name`` if given.

   Args:
     context: object providing ``mesh``, ``model_dim`` and ``variable_dtype``.
     x: an mtf.Tensor whose shape contains ``context.model_dim``.
     name: optional variable-scope name.

   Returns:
     An mtf.Tensor shaped like ``x``.
   """
   with tf.variable_scope(name, default_name="layer_norm"):
     gain = mtf.get_variable(
         context.mesh,
         "scale",
         mtf.Shape([context.model_dim]),
         initializer=tf.ones_initializer(),
         dtype=context.variable_dtype)
     mean_square = mtf.reduce_mean(
         mtf.square(x), reduced_dim=context.model_dim)
   normalized = x * mtf.rsqrt(mean_square + self._norm_epsilon)
   return normalized * gain
Ejemplo n.º 5
0
def clip_by_global_norm(grads, clip_norm):
    """Rescale gradients so their combined L2 norm does not exceed ``clip_norm``.

    Args:
        grads: list of mtf.Tensor; ``None`` entries are ignored when
            computing the norm and are passed through unchanged.
        clip_norm: the maximum allowed global norm.

    Returns:
        (clipped_grads, global_norm): a list aligned with ``grads``, and the
        global norm measured before clipping.
    """
    present = [g for g in grads if g is not None]
    squared_sums = [mtf.reduce_sum(mtf.square(g)) for g in present]
    global_norm = mtf.sqrt(mtf.add_n(squared_sums))
    # Equals 1 when the norm is already within bounds, otherwise shrinks all
    # gradients by the same factor.
    scale = clip_norm / mtf.maximum(global_norm, clip_norm)
    clipped = []
    for g in grads:
        clipped.append(None if g is None else g * scale)
    return clipped, global_norm
def norm(x, axis=None, epsilon=1e-5):
    """Standardize ``x`` to zero mean and unit variance along one dimension.

    Args:
        x: an mtf.Tensor.
        axis: the mtf.Dimension to normalize over. When omitted, ``default``
            is given ``x.shape[-1]`` (presumably used as the fallback for
            ``None`` — confirm against ``default``).
        epsilon: added to the variance before the reciprocal square root.

    Returns:
        An mtf.Tensor shaped like ``x``. No learned scale or bias is applied.
    """
    dim = default(axis, x.shape[-1])

    mean = mtf.reduce_mean(x, reduced_dim=dim)
    variance = mtf.reduce_mean(mtf.square(x - mean), reduced_dim=dim)

    mean = mtf.broadcast(mean, x.shape)
    variance = mtf.broadcast(variance, x.shape)

    centered = x - mean
    return centered * mtf.rsqrt(variance + epsilon)
Ejemplo n.º 7
0
    def apply_grad(self, grad, var):
        """See base class.

        Builds one Adam step with decoupled weight decay for ``var``. No bias
        correction is applied to the moment estimates.

        Args:
            grad: gradient mtf.Tensor for ``var``, or None.
            var: the mtf variable to update.

        Returns:
            ``[var_update, m_update, v_update]`` assignment ops, or ``[]``
            (after logging a warning) when ``grad`` is None.
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)

        def _zero_slot(suffix):
            # Non-trainable accumulator with the same mesh/shape as ``var``.
            return mtf.get_variable(
                var.mesh,
                var.name + suffix,
                var.shape,
                initializer=tf.zeros_initializer(),
                trainable=False)

        first_moment = _zero_slot("/adam_m")
        second_moment = _zero_slot("/adam_v")

        # Exponential moving averages of the gradient and squared gradient.
        m_new = self.beta_1 * first_moment + (1.0 - self.beta_1) * grad
        v_new = (self.beta_2 * second_moment
                 + (1.0 - self.beta_2) * mtf.square(grad))

        step = m_new / (mtf.sqrt(v_new) + self.epsilon)

        # Weight decay is added to the update directly rather than folded into
        # the loss. An L2 term in the loss would flow through m/v and be
        # rescaled by Adam; applying it here acts like plain-SGD L2 decay.
        if self._do_use_weight_decay(var.name):
            step = step + mtf.to_float(var.value) * self.weight_decay_rate

        scaled_step = self.learning_rate * step

        return [
            mtf.assign_sub(var, scaled_step),
            mtf.assign(first_moment, m_new),
            mtf.assign(second_moment, v_new),
        ]
Ejemplo n.º 8
0
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorflow.

    Two bias-free dense layers map io -> hidden -> io. The loss is the summed
    squared difference between the output and the input.

    Args:
        features: a tf.Tensor imported with mtf shape [batch, io].
        mesh: the mtf.Mesh to build on.

    Returns:
        (y, loss): the output mtf.Tensor and the scalar summed squared error.
    """
    batch = mtf.Dimension('batch', FLAGS.batch_size)
    hidden = mtf.Dimension('hidden', FLAGS.hidden_size)
    io = mtf.Dimension('io', FLAGS.io_size)

    inputs = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch, io]))
    activations = mtf.layers.dense(
        inputs, hidden, name='layer1', use_bias=False)
    outputs = mtf.layers.dense(
        activations, io, name='layer2', use_bias=False)

    return outputs, mtf.reduce_sum(mtf.square(outputs - inputs))
Ejemplo n.º 9
0
def model_backbone(features, labels, mesh):
    """Build the ``network[args_opt.model]`` graph on ``mesh`` plus its loss.

    Args:
        features: a pair ``(id_hldr, wt_hldr)`` of tf.Tensors. Each is
            reshaped to ``[args_opt.batch_size, 39]`` and imported as an mtf
            tensor with dims (batch, field).
        labels: a tf.Tensor reshaped to ``[args_opt.batch_size]``, or None.
        mesh: a mtf.Mesh.

    Returns:
        (logits, loss): ``logits`` is the network output cast to float32.
        ``loss`` is a scalar mtf.Tensor, or None when ``labels`` is None.
        Previously that case crashed with TypeError on ``None + None``.
    """
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh, tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        # VariableDType with every component float16, handed to the network.
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](id_hldr,
                                                      wt_hldr,
                                                      vocab_dim,
                                                      embed_dim,
                                                      outdim,
                                                      float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)

    if labels is None:
        # No labels (e.g. inference): there is no loss to build.
        return logits, None

    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [args_opt.batch_size]),
        mtf.Shape([batch_dim]))
    cross_entropy = mtf.layers.sigmoid_cross_entropy_with_logits(
        logits, labels)
    # Half the mean squared embedding value, weighted by 8e-5 below.
    l2_penalty = mtf.reduce_mean(mtf.square(embedding_table)) / 2
    deep_loss = mtf.reduce_mean(cross_entropy) + 8e-5 * l2_penalty
    wide_loss = mtf.reduce_mean(cross_entropy)

    # NOTE(review): deep_loss already contains mean(cross_entropy), so the
    # returned sum counts the cross-entropy term twice. This is kept as-is to
    # preserve training behavior; confirm it is intentional.
    return logits, wide_loss + deep_loss
Ejemplo n.º 10
0
def layer_norm(
    x,
    dim: mtf.Dimension,
    epsilon: float = 1e-6,
    subtract_mean=True,
    use_scale=True,
    use_bias=True,
    name=None,
):
    """Normalize ``x`` over ``dim``, then optionally scale and shift it.

    Args:
        x: a mtf.Tensor whose shape contains dim.
        dim: the mtf.Dimension to normalize over.
        epsilon: added to the variance before the reciprocal square root.
        subtract_mean: if True, center ``x`` over ``dim`` first. If False,
            this divides by the root-mean-square only.
        use_scale: if True, multiply by a learned ``scale`` (init ones).
        use_bias: if True, add a learned ``bias`` (init zeros).
        name: a string used for tf.variable_scope.

    Returns:
        a mtf.Tensor with the same shape as x.
    """
    with tf.variable_scope(name, default_name="layer_norm"):
        out = x
        if subtract_mean:
            out = out - mtf.reduce_mean(out, reduced_dim=dim)
        mean_square = mtf.reduce_mean(mtf.square(out), reduced_dim=dim)
        out = out * mtf.rsqrt(mean_square + epsilon)

        # Learned parameters match the activation dtype of the normalized value.
        if use_scale:
            gain = mtf.get_variable(
                out.mesh,
                "scale",
                mtf.Shape([dim]),
                initializer=tf.ones_initializer(),
                activation_dtype=out.dtype,
            )
            out = out * gain
        if use_bias:
            shift = mtf.get_variable(
                out.mesh,
                "bias",
                mtf.Shape([dim]),
                initializer=tf.zeros_initializer(),
                activation_dtype=out.dtype,
            )
            out = out + shift
        return out
Ejemplo n.º 11
0
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorflow.

    ``FLAGS.num_hidden_layers`` bias-free dense layers alternate between the
    ``hidden_odd`` and ``hidden_even`` dimensions, followed by a dense layer
    back to ``io_dim``. When ``FLAGS.step_with_nan >= 0``, a NaN is injected
    into the output at that global step, to test NaN handling.

    Args:
        features: a tf.Tensor imported with mtf shape [batch, io].
        mesh: the mtf.Mesh to build on.

    Returns:
        (y, loss): the output mtf.Tensor and the mean squared error vs. input.
    """
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    master_dtype = tf.as_dtype(FLAGS.master_dtype)
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

    imported = mtf.import_tf_tensor(
        mesh, features, mtf.Shape([batch_dim, io_dim]))
    inputs = mtf.cast(imported, activation_dtype)

    num_layers = FLAGS.num_hidden_layers + 1  # hidden layers + output layer
    outputs = inputs
    for layer_idx in range(1, num_layers + 1):
        if layer_idx == num_layers:
            out_dim = io_dim  # final layer projects back to the input dim
        else:
            parity = 'even' if layer_idx % 2 == 0 else 'odd'
            out_dim = mtf.Dimension('hidden_' + parity, FLAGS.hidden_size)
        outputs = mtf.layers.dense(outputs,
                                   out_dim,
                                   use_bias=False,
                                   master_dtype=master_dtype,
                                   slice_dtype=slice_dtype,
                                   name='layer_%d' % layer_idx)

    global_step = tf.train.get_global_step()
    if FLAGS.step_with_nan >= 0:
        # 0/0 is NaN exactly at step FLAGS.step_with_nan and 0/1 = 0 otherwise.
        # This tests whether MeshTensorFlow copes with occasional NaN values.
        denominator = tf.cond(tf.equal(global_step, FLAGS.step_with_nan),
                              lambda: 0., lambda: 1.)
        outputs += mtf.import_tf_tensor(
            mesh, tf.divide(0.0, denominator), mtf.Shape([]))

    loss = mtf.reduce_mean(mtf.square(outputs - inputs))
    return outputs, loss
Ejemplo n.º 12
0
def model_fn(features, labels, mode, params):
    """TPUEstimator ``model_fn`` for the mesh-tensorflow GPT model.

    Builds an mtf graph and mesh from ``params`` and imports ``features`` /
    ``labels`` as fully replicated int32 mtf tensors. It then runs
    ``gpt2.model``, serialized into microbatches when training asks for more
    than one. Finally it lowers the graph to TensorFlow and returns the
    TPUEstimatorSpec for ``mode``.

    Args:
        features: tf.Tensor of token ids, or a dict holding it under
            ``"feature"``. It is reshaped to ``[batch, params["n_ctx"]]``.
        labels: same format as ``features``, or None.
        mode: a ``tf.estimator.ModeKeys`` value (PREDICT, TRAIN or EVAL).
        params: dict of hyper-parameters. The dict returned by
            ``add_mode_to_params`` also receives ``num_microbatches`` and,
            in PREDICT, a defaulted ``remove_partial_sequences``.

    Returns:
        A ``tpu_estimator.TPUEstimatorSpec``.

    Raises:
        ValueError: if ``mode`` is not PREDICT/TRAIN/EVAL, or (in TRAIN/EVAL)
            if ``params["model"]`` is not ``"GPT"``.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            # Inputs may arrive wrapped as {"feature": tensor}. Unwrap them in
            # features_dict so later code (e.g. the lambada metric) sees the
            # tensor itself.
            if isinstance(x, dict):
                features_dict[key] = x["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        else:
            # Export path: run the model once; its first output is exported
            # as "outputs".
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating. Use an
    # explicit check rather than `assert`, which is stripped under `python -O`.
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise ValueError(f"Unsupported estimator mode: {mode}")

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params[
        "num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                # ValueError subclasses Exception, so existing handlers still match.
                raise ValueError(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                # NOTE(review): 0.29335 presumably converts per-token loss to
                # per-byte for the tokenizer/dataset in use — confirm.
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                # Score only positions whose label is not the EOS token.
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                # Use the unwrapped labels tensor: the raw `labels` argument
                # may still be a {"feature": tensor} dict.
                eval_metrics = (_lambada_metric_fn,
                                [features_dict["labels"], tf_max_logits,
                                 tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Ejemplo n.º 13
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of the reconstruction graph built on LPT deplacement.

    The linear field variable is evolved with LPT (or LPT + nbody when
    FLAGS.nbody is set), painted on the grid, passed through a two-branch
    neural-network bias model, and compared with `data` through a smoothed
    log-density chi-square plus a power-spectrum prior.

    Args:
      mesh: mtf.Mesh on which all mesh-tensorflow ops are created.
      data: observed field; imported with shape [batch, nc, nc, nc], and its
        `dtype` / `shape` are used to build placeholders.
      nc: number of grid cells per side.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: scale factor at which the LPT initial state is built (nbody mode)
        and first of the integration stages.
      a: final scale factor.
      nsteps: number of stages, linearly spaced between a0 and a.
      dtype: tf.float32 or tf.float64.

    Returns:
      Tuple (initc, model, loss, var_grads, update_op, linearop, input_field,
      lr, R0, M0, width, chisq, prior, off, istd). `input_field`, `lr`, `R0`,
      `M0`, `width`, `off` and `istd` are tf placeholders; `linearop` assigns
      `input_field` to the linear-field variable and `update_op` applies one
      high-pass-filtered gradient-descent step scaled by `lr`.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype stay unbound and the function
        # fails much later with an opaque UnboundLocalError.
        raise ValueError("Unsupported dtype %r: expected tf.float32 or "
                         "tf.float64" % (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Smoothing scales of the two Gaussian-filtered bias-model features.
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Keep the halo strictly below half of the smallest block side.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    scalar = mtf.Dimension("scalar", 1)

    # Real-space dimensions of the full grid.
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Fourier-space dimensions of the full grid.
    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Fourier-space dimensions of the "low resolution" grid; in this
    # single-resolution prototype it has the same size as the full grid.
    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    # Block decomposition: number of blocks and block side along each axis.
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only column 1 of the power spectrum table is used below, so read the
    # file once instead of twice.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels. Cast to npdtype (rather than a fixed
    # float32) so the wavenumbers match the requested precision.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # The "low resolution" grid is the same [nc, nc, nc] grid here, so the
    # same wavenumbers are reused under the *_lr dimension names.
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec[0].squeeze().astype(npdtype),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec[1].squeeze().astype(npdtype),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec[2].squeeze().astype(npdtype),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Linear field being optimized; `linearop` loads `input_field` into it.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    initc = fieldvar

    print("initc : ", initc)

    if FLAGS.nbody:
        # LPT initial state at a0, then nbody integration over `stages`.
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only, evaluated directly at the final scale factor.
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # Paint final_state[0] on halo-padded blocks.
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Custom reshape from the block layout back to the full grid.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    x = final_field

    # Bias-model parameters.
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)

    # FFT dimensions reordered as (kx, ky, kz) from kv = [ky, kz, kx].
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc = float_to_mtf(nc * 1., mesh, scalar)
    tfbs = float_to_mtf(bs, mesh, scalar)

    # Feature 0: deconvolved painted field minus one.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Features 1 and 2: the field smoothed at R1, and R1-smoothed minus
    # R2-smoothed.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        """First bias-model branch, applied slicewise.

        Convolves each input with `kernel`, standardizes the stacked features
        with (pmx, psx), runs three dense layers (ReLU on the first two) and
        returns sigmoid(width * output) for the first output channel.
        """
        # NOTE(review): 'SAME' zero-pads at slice edges rather than wrapping
        # periodically; confirm this is acceptable for the sharded layout.
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        """Second bias-model branch, applied slicewise.

        Stacks the three inputs as channels, standardizes with (mmx, msx),
        runs three dense layers (ELU on the first two) and rescales the first
        output channel with msy / mmy.
        """
        zz = tf.concat(
            [tf.expand_dims(x, -1),
             tf.expand_dims(x1, -1),
             tf.expand_dims(x2, -1)],
            axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Prior: |delta(k)|^2 / P(k) summed over the linear field.
    k_dims_pr = list(k_dims)
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Likelihood: model smoothed at R1, compared in log space with data.
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Annealing placeholders: R0 (smoothing scale of the residual),
    # M0 (offset inside the logs), off / istd (per-voxel offset and weight).
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off = tf.placeholder(tf.float32, shape=data.shape)
    istd = tf.placeholder(tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    # NOTE: mtfistd is currently unused; weighting by `(diff + mtfoff) *
    # mtfistd` was reported to misbehave, so a fixed 1 / 0.25 scale is used.
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian low-pass of width R0 (in cells) applied to the residual.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        # Keep only the part of the gradient not removed by a Gaussian
        # low-pass, i.e. multiply by (1 - low-pass weights).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return (initc, model, loss, var_grads, update_op, linearop, input_field,
            lr, R0, M0, width, chisq, prior, off, istd)
Ejemplo n.º 14
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of the multi-resolution reconstruction graph (LPT deplacement).

    A block-distributed linear field variable is split into a downsampled
    low-resolution part and a halo-padded high-resolution part, evolved with
    LPT (or LPT + nbody when FLAGS.nbody is set), painted on the grid, and
    compared with `data` through a smoothed chi-square plus a power-spectrum
    prior.

    Args:
      mesh: mtf.Mesh on which all mesh-tensorflow ops are created.
      data: observed field; imported with shape [batch, nc, nc, nc], and its
        `dtype` is used for the input placeholder.
      nc: number of grid cells per side.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: scale factor at which the LPT initial state is built (nbody mode)
        and first of the integration stages.
      a: final scale factor.
      nsteps: number of stages, linearly spaced between a0 and a.
      dtype: tf.float32 or tf.float64.

    Returns:
      Tuple (initc, final_field, loss, var_grads, update_op, linearop,
      input_field, lr, R0). `input_field`, `lr` and `R0` are tf placeholders;
      `linearop` assigns `input_field` to the block-shaped linear-field
      variable and `update_op` applies one high-pass-filtered gradient step.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype stay unbound and the function
        # fails much later with an opaque UnboundLocalError.
        raise ValueError("Unsupported dtype %r: expected tf.float32 or "
                         "tf.float64" % (dtype,))
    print(dtype, npdtype)

    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Keep the halo strictly below half of the smallest block side.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition: the low resolution grid
    # has nc // 2**downsampling_factor cells per side.
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    # Block decomposition: number of blocks and block side along each axis.
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only column 1 of the power spectrum table is used below, so read the
    # file once instead of twice.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid, rescaled by the downsampling factor.
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)

    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block side plus halo on both sides)
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks (unpadded block side)
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Linear field being optimized, stored in the block layout; `linearop`
    # loads `input_field` into it.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    field = fieldvar
    # Custom reshape from the block layout back to the full grid.
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Pad each block with a halo and exchange halos between neighbours.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # Split into high resolution (padded blocks) and downsampled parts.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using a custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    if FLAGS.nbody:
        # Build the LPT state at a0 (the first integration stage) rather than
        # at a hard-coded 0.1, so the state and `stages` agree when a0 != 0.1.
        state = mtfpm.lpt_init(low,
                               high,
                               a0,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        # LPT only, evaluated directly at the final scale factor.
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # Paint final_state[0] on halo-padded blocks.
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Prior: |delta(k)|^2 / P(k) summed over the block-shaped linear field.
    # FFT dimensions reordered as (kx, ky, kz) from kv_pr = [ky, kz, kx].
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss: Gaussian-smoothed residual chi-square plus prior.
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        # Gaussian low-pass of width R0 (in cells) applied to the residual.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        # Keep only the part of the gradient not removed by a Gaussian
        # low-pass, i.e. multiply by (1 - low-pass weights).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return (initc, final_field, loss, var_grads, update_op, linearop,
            input_field, lr, R0)
Ejemplo n.º 15
0
 def normalize(x):
     """RMS-normalize `x` over `self.model_dim` and apply the next scale.

     Pops the first entry of the enclosing `layer_norm_vars` list and uses it
     as the multiplicative scale; no mean is subtracted.
     """
     gain = layer_norm_vars.pop(0)
     mean_square = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
     inv_rms = mtf.rsqrt(mean_square + hparams.norm_epsilon)
     return x * inv_rms * gain
Ejemplo n.º 16
0
def norm(x, axis, epsilon=1e-8):
    """Standardize `x` along `axis` to zero mean and unit variance.

    No learned scale or shift is applied; `epsilon` is added to the variance
    before the reciprocal square root.
    """
    centered = x - mtf.reduce_mean(x, reduced_dim=axis,
                                   name="norm_reduce_mean_u")
    variance = mtf.reduce_mean(mtf.square(centered), reduced_dim=axis,
                               name="norm_reduce_mean_s")
    return centered * mtf.rsqrt(variance + epsilon)
Ejemplo n.º 17
0
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None,
              float32_logits=True,
              z_loss_coeff=None):
  """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    context: an optional Transformer.Context. When None, the call is treated
      as non-training: no z-loss is added and dropout is disabled.
    float32_logits: a boolean - if True, then compute logits in float32 to avoid
      numerical issues with bfloat16
    z_loss_coeff: a float, if z_loss_coeff is not None then add an auxiliary
      loss to push the attention logits closer to zero. This helps to stabilize
      model training. Requires a context in training mode.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  orig_q_shape = q.shape
  # NOTE(review): assumes this helper tolerates context=None - confirm.
  q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
      context, q, k, v, bias, [key_dim, value_dim])
  if float32_logits:
    k = mtf.cast(k, tf.float32)
    q = mtf.cast(q, tf.float32)
  logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += mtf.cast(bias, logits.dtype)

  # `context` is optional; without one there is no training flag, so behave
  # as inference instead of raising AttributeError on `None.train`. When a
  # context is given this is exactly `context.train`.
  is_training = context is not None and context.train

  # Adds auxiliary z-loss to push the attention logits towards zero.
  if z_loss_coeff is not None and is_training:
    tf.logging.info("attention z_loss being added: {}".format(
        tf.get_variable_scope().name))
    log_z = mtf.reduce_logsumexp(logits, memory_length_dim)
    z_loss = mtf.square(log_z) * mtf.cast(context.nonpadding, log_z.dtype)
    z_loss = mtf.reduce_mean(z_loss)
    if context.num_microbatches and context.num_microbatches > 1:
      tf.logging.info(
          "Dividing attention z-loss loss by num_microbatches={}".format(
              context.num_microbatches))
      z_loss /= context.num_microbatches
    # Already inside the training-only branch; no second check needed.
    mtf.scalar_summary("attention_z_loss", z_loss)
    z_loss *= z_loss_coeff
    context.losses.append(mtf.cast(z_loss, v.dtype))

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.cast(weights, v.dtype)
  weights = mtf.dropout(
      weights, is_training, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)
  outputs_shape = q.shape - key_dim + value_dim
  outputs = mtf.einsum([weights, v], outputs_shape)
  # Undo any 2d-sharding reshape applied to q above.
  outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
  return outputs
Ejemplo n.º 18
0
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the reconstruction loss for a neural-bias + RSD forward model.

    The trainable linear field 'linear' is evolved with LPT followed by an
    N-body integration and painted onto the mesh.  Two small fixed networks
    (weights from ``setupfnn()``) map the painted field and two
    Gaussian-smoothed versions of it to a model field.  The model is then
    shifted along z by the smoothed velocity (times ``rsdfactor``) to add
    redshift-space distortions, smoothed, and compared in log space with
    ``datasm``.  A Gaussian prior on the linear field is added to the loss.

    Args:
      mesh: the mtf.Mesh the graph is built on.
      datasm: data field, imported with shape [batch, nx, ny, nz].
      rsdfactor: multiplier applied to the z velocity read out at the grid
        points before it is used as a z displacement.
      M0: offset added to both model and data before taking the log.
      R0: scale of the Gaussian filter applied to the residual.
      width: slope passed to the sigmoid of the first network.
      off: optional additive offset for the log residual (None to skip).
      istd: optional per-cell weight multiplied into the residual; when None
        the residual is divided by 0.25 instead.
      x0: optional initial value of the linear field; None draws N(0, 1).
      nc, bs, batch_size: mesh size, box size and batch size.
      a0, a, nsteps: initial and final scale factor and number of steps.
      dtype: tf.float32 or tf.float64.

    Returns:
      (fields, metrics, kv) where fields is [linear field, painted final
      field, RSD model field], metrics is [chisq, prior, loss] and kv is the
      list of wave-vector tensors [ky, kz, kx].

    Raises:
      ValueError: if dtype is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Fail here with a clear message rather than a NameError on
        # npdtype/cdtype further down.
        raise ValueError("dtype must be tf.float32 or tf.float64, got %s" %
                         dtype)
    print("Dtype : ", dtype, npdtype)

    # Gaussian smoothing scales: R1 and R2 build the network inputs, R1 is
    # also used for the mass / momentum / model fields.
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half a block; shrink it if it does.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Named dimensions
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Linear power spectrum values (column 1 of the table), used by the prior.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation
    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)

    # --- Neural bias model ---
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Deconvolve the CIC window and convert to an overdensity.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two Gaussian-smoothed versions and their difference feed the networks.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        # Convolve each input with the fixed kernel, then run the 3-layer
        # network; the sigmoid output is used as a multiplicative mask.
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=splittables)

    def apply_mwts(x, x1, x2):
        # Point-wise 3-layer network on the stacked inputs; the output is
        # de-standardized with (msy, mmy).
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=splittables)

    model = pmodel * mmodel

    # --- Redshift-space distortions ---
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    # Regular grid positions, broadcast over the batch dimension.
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # Smoothed mass field (with a small floor so the division below is safe).
    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    masssm = masssm + 1e-5

    # z component of the particle velocities.
    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    # Paint the z velocity as particle weights onto an empty grid, so the
    # result is the z momentum field alone (not added on top of the mass).
    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(momz, final_state, output_shape=part_shape, hr_shape=hr_shape, \
                              halo_size=halo_size, splittables=splittables, mesh=mesh, weights=vzweights)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    # Smoothed velocity = smoothed momentum / smoothed mass, read out at the
    # grid positions and scaled into a z displacement.
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X],
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    # Shift the grid positions along z by the scaled smoothed velocity.
    # NOTE(review): assumes vz has the same shape as vzweights (both are used
    # as per-particle weights with cic_paint_fr) — confirm.
    Xrsd = mtf.slicewise(lambda pos, dz: pos + tf.stack(
        [tf.zeros_like(dz), tf.zeros_like(dz), dz], 4), [X, vz],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)
    modelread = mcomp.cic_readout_fr(model, [X],
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd], output_shape=part_shape, hr_shape=hr_shape, \
                                  halo_size=halo_size, splittables=splittables, mesh=mesh, weights=modelread)

    model = modelrsd
    print(modelrsd)

    # --- Likelihood and prior ---
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |field(k)| / sqrt(P(|k|)), with P interpolated on a log-regular
        # grid between 1e-5 and 1e3.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Smooth the RSD model with the same R1 filter before comparison.
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Log-space residual
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        # The offset (if any) has already been added above; only weight here.
        diff = diff * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply by exp(-k^2 (R0 * bs / nc)^2).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
Ejemplo n.º 19
0
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build a Poisson-likelihood reconstruction loss for a linear field.

    The trainable linear field 'linear' is evolved with LPT; if FLAGS.nbody is
    set, the result is integrated further with N-body steps.  It is then
    CIC-painted, Gaussian-smoothed with scale R0, and scored against ``data``
    with a Poisson log-likelihood of rate FLAGS.plambda * (1 + field).  A
    Gaussian prior on the linear field is added.

    Args:
      mesh: the mtf.Mesh the graph is built on.
      data: observed field, imported with shape [batch, nx, ny, nz].
      R0: scale of the Gaussian filter applied to the painted field.
      x0: optional initial value of the linear field; None draws N(0, 1).
      nc, bs, batch_size: mesh size, box size and batch size.
      a0, a, nsteps: initial and final scale factor and number of steps.
      dtype: tf.float32 or tf.float64.

    Returns:
      (fields, metrics, kv) where fields is [linear field, Poisson sample],
      metrics is [chisq, prior, loss] and kv is [ky, kz, kx].

    Raises:
      ValueError: if dtype is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Fail here with a clear message rather than a NameError on
        # npdtype/cdtype further down.
        raise ValueError("dtype must be tf.float32 or tf.float64, got %s" %
                         dtype)
    print(dtype, npdtype)

    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half a block; shrink it if it does.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Named dimensions
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Linear power spectrum values (column 1 of the table), used by the prior.
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation
    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    if FLAGS.nbody:
        # LPT to a0, then N-body integration over `stages`.
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT directly to the final stage.
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # Paint the particles onto halo-padded blocks.
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |field(k)| / sqrt(P(|k|)), with P interpolated on a log-regular
        # grid between 1e-5 and 1e3.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Only used below for its shape.
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply by exp(-k^2 (R0 * bs / nc)^2).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Poisson likelihood: the smoothed field sets the rate
    # plambda * (1 + field) in every cell.
    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        # Negative Poisson log-probability of the data in each cell.
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)
    chisq = mtf.reduce_sum(chisq)

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        # Draw a Poisson realisation from the model rate (data is unused).
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
Ejemplo n.º 20
0
 def normalize(x):
   """RMS-normalize x over self.model_dim, then apply the next stored scale.

   Each call consumes (pops) the first remaining entry of layer_norm_vars.
   """
   gain = layer_norm_vars.pop(0)
   mean_square = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
   inv_rms = mtf.rsqrt(mean_square + hparams.norm_epsilon)
   return x * inv_rms * gain
Ejemplo n.º 21
0
    else:
        float16 = None
    result, embedding_table = widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim,
                                       outdim, float16)

    # label = mtf.reshape(label,new_shape=[batch_dim, outdim])
    # output = mtf.layers.softmax_cross_entropy_with_logits(result, label,vocab_dim=outdim)
    # result = mtf.sigmoid(result)
    # result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
    # result = mtf.reduce_sum(result)
    result = mtf.cast(result, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    probability = mtf.sigmoid(result)
    result = mtf.layers.sigmoid_cross_entropy_with_logits(result, label)
    wide_loss = mtf.reduce_mean(result)
    deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
    deep_loss = mtf.reduce_mean(result) + 8e-5 * deep_loss

    # print("========",global_step)
    devices = ["gpu:0"]
    mesh_shape = [("all_processors", 1)]
    layout_rules = [("dim1", "all_processors")]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    wide_loss = lowering.export_to_tf_tensor(wide_loss)
    # result = lowering.export_to_tf_tensor(result)
    # predict = lowering.export_to_tf_tensor(probability)
    # predict = tf.where(predict>0.5,tf.ones_like(predict),tf.zeros_like(predict))
    deep_loss = lowering.export_to_tf_tensor(deep_loss)
    print(wide_loss)
Ejemplo n.º 22
0
        def model(self, mesh, x, y, params):
            """Build a toy embedding + stacked dense model and its MSE loss.

            Args:
              mesh: the mtf.Mesh the graph is built on.
              x: input imported with shape [batch, sequence, vocab].
              y: target compared against the prediction in the loss.
              params: dict providing "precision", "batch_size",
                "vocab_size", "io" and "io_channels".

            Returns:
              (prediction, loss): the last dense output and the mean squared
              error between y and it.
            """
            if params["precision"] == "bfloat16":
                # master has type float32, slice and activation have type bfloat16
                variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                                   tf.bfloat16)
            else:
                # master, slice and activation are all float32
                variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                                   tf.float32)

            # Build the actual model
            batch_dim = mtf.Dimension("batch", params["batch_size"])
            vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
            io_dim = mtf.Dimension("sequence", params["io"])
            io_chan_dim = mtf.Dimension("io", params["io_channels"])

            # from input to mtf
            x = mtf.import_tf_tensor(mesh, x,
                                     mtf.Shape([batch_dim, io_dim, vocab_dim]))

            # Embeddings. variable_scope takes the scope name positionally
            # (there is no `scope=` keyword).
            with tf.variable_scope("toy", default_name="seq2seq"):
                with tf.variable_scope("embeddings"):
                    # Perform embedding lookup on the word ids.
                    embedding_table = mtf.get_variable(
                        mesh,
                        "word_embeddings",
                        mtf.Shape([vocab_dim, io_chan_dim]),
                        initializer=self.embedding_initializer,
                    )

                    word_embedding_output = mtf.gather(
                        embedding_table,
                        x,
                        dim=vocab_dim,
                        output_shape=io_chan_dim)

                    # Add positional embeddings and token type embeddings, then layer
                    # normalize and perform dropout.
                    embedding_output = word_embedding_output

                    pos_embedding = mtf.get_variable(
                        mesh,
                        "pos_embeddings",
                        mtf.Shape([io_dim, io_chan_dim]),
                        initializer=self.embedding_initializer,
                    )
                    # NOTE(review): embedding_output is normalized and dropped
                    # out here but never used below (x is built from
                    # word_embedding_output) — confirm which was intended.
                    embedding_output = self.normalize(embedding_output)
                    embedding_output = mtf.dropout(
                        embedding_output,
                        keep_prob=1.0 - self.config.layer_output_dropout_prob,
                    )

                # shift token by pos embeddings
                x = word_embedding_output + pos_embedding
                x = mtf.cast(x, variable_dtype.activation_dtype)

                h = x
                for lnum in range(1, self.num_hidden_layers + 2):
                    if lnum + 1 == self.num_hidden_layers + 2:
                        # output layer
                        # NOTE(review): io_dim ("sequence") is already a
                        # dimension of x via pos_embedding; projecting onto it
                        # may collide with an existing dim of h — confirm the
                        # intended output dimension.
                        dim = io_dim
                    elif lnum % 2 == 0:
                        # Dimension sizes must be ints, not Dimensions.
                        dim = mtf.Dimension("hidden_even", io_chan_dim.size)
                    else:
                        dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
                    # Every layer, including the output layer, gets a dense
                    # projection onto the dimension chosen above.
                    h = mtf.layers.dense(
                        h,
                        dim,
                        use_bias=False,
                        master_dtype=variable_dtype.master_dtype,
                        slice_dtype=variable_dtype.slice_dtype,
                        name="layer_%d" % lnum,
                    )

                prediction = h
                # project back to token dimensions

                # compute the mean square loss between the input and the output
                loss = mtf.reduce_mean(mtf.square(y - prediction))
                return prediction, loss
Ejemplo n.º 23
0
def recon_model(mesh,
                data,
                bparams,
                ipkerror,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32,
                pk_path='..//data/Planck15_a1p00.txt'):
    """Build the mesh-TensorFlow reconstruction graph and its loss.

    The linear field is an mtf variable named 'linear'. It is evolved with
    ``mtfpm.lpt_init_single`` followed by ``mtfpm.nbody_single``. A model
    field ``b1 * ed + b2 * ed2 + bs2 * es2`` is then formed, where ``ed``,
    ``ed2`` and ``es2`` are the mean-subtracted linear field, its square and
    the output of ``shear``. Each is read out at regular grid positions and
    painted at the evolved particle positions.

    The loss is the sum of two terms:
      * a chi-square of ``model - data`` computed in Fourier space, divided
        by sqrt of the interpolated error spectrum and multiplied by
        ``exp(-k^2 (R0 * bs / nc)^2)``;
      * a prior ``sum(|FFT(linear)|^2 / P_lin(k)) * bs**3``.

    Args:
      mesh: mtf.Mesh on which the graph is built.
      data: observed field. It is passed through ``tf.convert_to_tensor``
        and imported with shape [batch, nx, ny, nz].
      bparams: (b1, b2, bs2) coefficients of the model terms.
      ipkerror: pair (k, P_error(k)) of numpy arrays for the error spectrum.
      R0: smoothing parameter applied to the residual (see formula above).
      x0: initial value for the linear field, or None for a standard-normal
        random initialization.
      nc: number of grid cells per side.
      bs: box size.
      batch_size: size of the batch dimension.
      a0, a, nsteps: start scale factor, end scale factor and number of
        stages, used as ``np.linspace(a0, a, nsteps)``.
      dtype: tf.float32 or tf.float64. Selects the complex dtype of the FFTs
        and the numpy dtype of the imported spectra.
      pk_path: text file whose second column is the tabulated linear power
        spectrum. The default is relative to the working directory.

    Returns:
      fields: [linear field variable, painted final field, model field].
      metrics: [chisq, prior, loss] mtf tensors.
      kv: imported wavenumber tensors, ordered [ky, kz, kx].

    Raises:
      ValueError: if ``dtype`` is neither tf.float32 nor tf.float64.
    """

    b1, b2, bs2 = bparams
    kerror = ipkerror[0].astype(np.float32)
    perror = ipkerror[1].astype(np.float32)

    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch an unsupported dtype surfaced later as an
        # UnboundLocalError on ``npdtype``; fail early with a clear message.
        raise ValueError("Unsupported dtype %r: expected tf.float32 or "
                         "tf.float64" % (dtype,))
    print("Dtype : ", dtype, npdtype)

    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales (block) decomposition.
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Cap the halo below half of the smallest block extent.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Named dimensions.
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Tabulated linear power spectrum. Only the P(k) column is used, so the
    # file is read once.
    plin = np.loadtxt(pk_path).T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])
    pke_dim = mtf.Dimension("epk", len(perror))
    pkerror = mtf.import_tf_tensor(mesh,
                                   perror.astype(npdtype),
                                   shape=[pke_dim])

    # Fourier wavenumbers. The "low resolution" grid has the same size (nc)
    # as the full grid, so one fftk result is imported under both sets of
    # dimension names.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)

    def _import_k(axis, dim):
        # Import one (squeezed) wavenumber axis as a 1-D float32 mtf tensor.
        return mtf.import_tf_tensor(mesh,
                                    kvec[axis].squeeze().astype('float32'),
                                    shape=[dim])

    kx = _import_k(0, tfx_dim)
    ky = _import_k(1, tfy_dim)
    kz = _import_k(2, tfz_dim)
    kv = [ky, kz, kx]

    kx_lr = _import_k(0, tx_dim)
    ky_lr = _import_k(1, ty_dim)
    kz_lr = _import_k(2, tz_dim)
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation: the linear field is the optimized variable.
    if x0 is None:
        initializer = tf.random_normal_initializer(mean=0.0,
                                                   stddev=1,
                                                   seed=None)
    else:
        initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=initializer)

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    # Paint the evolved particles (unit weights) onto the grid.
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                     hr_shape, halo_size, splittables, mesh)

    # Regular grid positions used to read out the model terms.
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # FFT output dimensions, ordered [kx, ky, kz] from kv = [ky, kz, kx].
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    tfnc = cswisef.float_to_mtf(nc * 1., mesh, scalar)
    tfbs = cswisef.float_to_mtf(bs, mesh, scalar)

    # Model terms built from the linear field.
    initc = fieldvar
    d0 = initc - mtf.reduce_mean(initc)
    d2 = initc * initc
    d2 = d2 - mtf.reduce_mean(d2)
    cinit = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
    shearfield = mtf.zeros(mesh, shape=part_shape)
    shearfield = shear(shearfield, cinit, kv, tfnc, tfbs)
    s2 = shearfield - mtf.reduce_mean(shearfield)

    def _readout(field):
        # CIC readout of `field` at the regular grid positions X.
        return mcomp.cic_readout_fr(field, [X],
                                    hr_shape=hr_shape,
                                    halo_size=halo_size,
                                    splittables=splittables,
                                    mesh=mesh)

    def _paint(weights):
        # CIC paint of the evolved particles carrying `weights`.
        return mcomp.cic_paint_fr(mtf.zeros(mesh, shape=part_shape),
                                  final_state,
                                  output_shape=part_shape,
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh,
                                  weights=weights)

    dread = _readout(d0)
    d2read = _readout(d2)
    s2read = _readout(s2)

    ed = _paint(dread)
    ed2 = _paint(d2read)
    es2 = _paint(s2read)

    model = ed * b1 + ed2 * b2 + es2 * bs2
    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)
    diff = model - mtfdata

    # Prior on the linear field.
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |kfield| / sqrt(P_lin(|k|)) with P_lin interpolated in log k.
        # NOTE(review): assumes `pk` is tabulated on a log-uniform k grid
        # spanning [1e-5, 1e3]; confirm against the spectrum file.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #* nc**3

    # Data term: residual in Fourier space.
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

    def _cwise_diff(kfield, pk, kx, ky, kz):
        # kfield / sqrt(P_error(|k|)), interpolated linearly over
        # [kerror.min(), kerror.max()].
        # NOTE(review): assumes kerror is uniformly spaced; confirm.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(x=kk,
                                                 x_ref_min=kerror.min(),
                                                 x_ref_max=kerror.max(),
                                                 y_ref=pk)
        priormesh = tf.reshape(pkmesh, kshape)
        priormesh = tf.cast(priormesh**0.5, kfield.dtype)
        return kfield / priormesh

    cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv, output_dtype=cdtype)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply by the Gaussian window exp(-k^2 (R0 * bs / nc)^2).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv