Exemple #1
0
def _pad_channels_dim(tensor, size):
  """Pad the trailing dimension of `tensor` so that it has `size` entries.

  Args:
    tensor: object whose `shape.dims[-1]` exposes `size` and `name`.
    size: requested size of the trailing dimension.

  Returns:
    `tensor` itself when its trailing dimension already has `size` entries,
    otherwise the result of `mtf.pad` with all the padding placed after the
    existing entries.

  Raises:
    ValueError: if the trailing dimension is already larger than `size`.
  """
  last_dim = tensor.shape.dims[-1]
  current = last_dim.size
  if current > size:
    raise ValueError("Cannot pad to size of {} when the original size "
                     "of {} is bigger".format(size, current))
  if current == size:
    # Already the requested size: hand the input back untouched.
    return tensor
  return mtf.pad(tensor, [0, size - current], last_dim.name)
Exemple #2
0
def halo_reduce(x, blocks_dim, block_size_dim, halo_size, wrap=True):
    """Reduce each block with the margins of adjacent blocks.

    The first and last ``2 * halo_size`` entries of every block along
    ``block_size_dim`` are its margins. Each margin is summed with the
    opposite margin of the neighbouring block, obtained by shifting by one
    position along ``blocks_dim``. The entries between the two margins are
    kept as they are.

    Args:
      x: a Tensor.
      blocks_dim: a Dimension in x.shape, indexing the blocks.
      block_size_dim: a Dimension in x.shape, indexing positions inside a
        block. Its size must be at least ``4 * halo_size`` so that the two
        margins do not overlap.
      halo_size: an integer; 0 returns ``x`` unchanged.
      wrap: a boolean, forwarded to ``mtf.shift``.

    Returns:
      a Tensor whose size along ``block_size_dim`` is the same as in ``x``.

    Raises:
      AssertionError: if ``4 * halo_size`` exceeds the size of
        ``block_size_dim``.
    """
    if halo_size == 0:
        return x
    block_size = block_size_dim.size
    margin = 2 * halo_size
    # The centre slice below has length block_size - 2 * margin, so both
    # margins must fit in one block. The previous check
    # (halo_size <= block_size // 2) let negative slice lengths through.
    assert 2 * margin <= block_size, (
        "halo_size %s is too large for a block of size %s" %
        (halo_size, block_size))
    dim_name = block_size_dim.name

    left_margin = mtf.slice(x, 0, margin, dim_name)
    right_margin = mtf.slice(x, block_size - margin, margin, dim_name)
    center = mtf.slice(x, margin, block_size - 2 * margin, dim_name)

    # Halo exchange: add the neighbour's facing margin to each margin.
    left = mtf.shift(right_margin, 1, blocks_dim, wrap) + left_margin
    right = mtf.shift(left_margin, -1, blocks_dim, wrap) + right_margin

    # Recompose the block: pad every piece back to block_size at its own
    # offset, then sum the three non-overlapping pieces.
    left = mtf.pad(left, [0, block_size - margin], dim_name)
    right = mtf.pad(right, [block_size - margin, 0], dim_name)
    center = mtf.pad(center, [margin, margin], dim_name)
    return left + center + right
Exemple #3
0
def lpt_init_single(lr_field,
                    a0,
                    kvec_lr,
                    halo_size,
                    lr_shape,
                    hr_shape,
                    part_shape,
                    antialias=True,
                    order=1,
                    post_filtering=True,
                    cosmology=Planck15):
  """Build an (X, P, F) state by displacing grid particles with `lr_field`.

  The field is Fourier transformed, passed through the gradient-laplace
  kernel, and each resulting component is transformed back, reshaped to the
  blocked layout, halo-padded and read out at the particle positions. The
  stacked readout is scaled by factors from `PerturbationGrowth` at `a0`.

  Args:
    lr_field: input mtf Tensor; its first dimension is used as batch.
    a0: scale factor at which the growth factors are evaluated.
    kvec_lr: sequence of three 1-D tensors of wave numbers.
    halo_size: padding added on each side of every block dimension.
    lr_shape: dimensions of the un-blocked field.
    hr_shape: dimensions of the blocked field (batch, 3 block-count dims,
      3 block-size dims).
    part_shape: dimensions used to create the particle index grid.
    antialias: not used in this function.
    order: not used in this function.
    post_filtering: not used in this function.
    cosmology: forwarded to `PerturbationGrowth`.

  Returns:
    A tuple (X, P, F): displaced positions, and two further tensors obtained
    by scaling the displacement.
  """
  batch_dim = lr_field.shape[0]
  lnc = lr_shape[-1].size

  # Create particles on the grid, replicated over the batch dimension
  mstate = mesh_ops.mtf_indices(lr_field.mesh,
                                shape=part_shape,
                                dtype=tf.float32)
  X = mtf.einsum([mtf.ones(lr_field.mesh, [batch_dim]), mstate],
                 output_shape=[batch_dim] + mstate.shape[:])

  raw_k_dims = [d.shape[0] for d in kvec_lr]
  k_dims_lr = [raw_k_dims[2], raw_k_dims[0], raw_k_dims[1]]

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  gradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

  # Reorder the low res FFTs which were transposed: y, z, x
  gradients = [gradients[2], gradients[0], gradients[1]]

  def _add_block_axes(t):
    # Insert three singleton axes right after the leading axis.
    for _ in range(3):
      t = tf.expand_dims(t, axis=1)
    return t

  blocked_shape = mtf.Shape(hr_shape[0:4] + [
      mtf.Dimension('sx_block', lnc // hr_shape[1].size),
      mtf.Dimension('sy_block', lnc // hr_shape[2].size),
      mtf.Dimension('sz_block', lnc // hr_shape[3].size)
  ])
  splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

  displacement = []
  for component in gradients:
    real = mesh_utils.c2r3d(component, lr_shape[-3:])
    real = mtf.slicewise(_add_block_axes, [real],
                         output_dtype=tf.float32,
                         output_shape=blocked_shape,
                         name='my_reshape',
                         splittable_dims=splittables)

    for size_dim in hr_shape[-3:]:
      real = mtf.pad(real, [halo_size, halo_size], size_dim.name)
    for blocks_dim, padded_dim in zip(hr_shape[1:4], real.shape[-3:]):
      real = mesh_ops.halo_reduce(real, blocks_dim, padded_dim, halo_size)
    # Readout to particle positions
    displacement.append(mesh_utils.cic_readout(real, X, halo_size))
  displacement = mtf.stack(displacement, "ndim", axis=4)

  growth = PerturbationGrowth(cosmology, a=[a0], a_normalize=1.0)
  DX = growth.D1(a0) * displacement
  p_factor = a0 ** 2 * growth.f1(a0) * growth.E(a0)
  P = p_factor * DX
  f_factor = a0 ** 2 * growth.E(a0) * growth.gf(a0) / growth.D1(a0)
  F = f_factor * DX
  # TODO: Implement 2nd order LPT

  # Moves the particles according to displacement
  return X + DX, P, F
def fr_to_hr(field, hr_shape, halo_size, splittables, mesh):
    """Reshape `field` to the blocked `hr_shape` layout and add halos.

    Args:
      field: mtf Tensor to reshape.
      hr_shape: target shape; entries 1..3 are the block-count dimensions
        and the last three are the block-size dimensions.
      halo_size: padding added on both sides of each block-size dimension.
      splittables: forwarded to `mtf.slicewise` as `splittable_dims`.
      mesh: not used in this function.

    Returns:
      The reshaped, padded and halo-reduced tensor.
    """

    def _add_block_axes(t):
        # Insert three singleton axes right after the leading axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    # Reshaping array into high resolution mesh
    blocked = mtf.slicewise(_add_block_axes, [field],
                            output_dtype=field.dtype,
                            output_shape=hr_shape,
                            name='my_reshape',
                            splittable_dims=splittables)

    for size_dim in hr_shape[-3:]:
        blocked = mtf.pad(blocked, [halo_size, halo_size], size_dim.name)

    for blocks_dim, padded_dim in zip(hr_shape[1:4], blocked.shape[-3:]):
        blocked = mesh_ops.halo_reduce(blocked, blocks_dim, padded_dim,
                                       halo_size)

    return blocked
Exemple #5
0
def test_sampling():
    """Check that autoregressive sampling and its lowering run without raising.

    Builds a length-1 prompt padded with three extra positions, assembles the
    auxiliary features read from the module-level `params`, and runs
    `sample_autoregressive` followed by an export to a TF tensor inside
    `not_raises(Exception)`.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    prompt = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(prompt, [0, 3], sequence_dim.name)

    # create mask
    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {
        "attn_bias": biasmask_attn_weights(mesh, length_dim,
                                           memory_length_dim,
                                           mtf.VariableDType(tf.float32)),
        "embd_dim": embd_dim,
        "vocab_dim": vocab_dim,
        "embed_sequence_dim": embed_sequence_dim,
        "memory_length_dim": memory_length_dim,
    }

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs,
            other_features=other_features,
            params=params,
            variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"],
            stop_at_token=params["eos_id"],
            sampling_use_entmax=True)

        # Lower onto a trivial placement mesh.
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
def cic_readout_fr(field, state, hr_shape, halo_size, splittables, mesh):
    """Read `field` out at the positions stored in `state[0]`.

    Ops performed:
    - reshape to the blocked layout derived from `hr_shape`
    - pad each block-size dimension by `halo_size`
    - halo_reduce
    - cic_readout

    Args:
      field: mtf Tensor; the size of its last dimension sets the block sizes.
      state: sequence whose first element holds the particle positions.
      hr_shape: blocked shape; entries 1..3 are the block-count dimensions
        and the last three are the block-size dimensions.
      halo_size: padding added on both sides of each block-size dimension.
      splittables: forwarded to `mtf.slicewise` as `splittable_dims`.
      mesh: not used in this function.

    Returns:
      The result of `mesh_utils.cic_readout`.
    """
    lnc = field.shape[-1].size

    def _add_block_axes(t):
        # Insert three singleton axes right after the leading axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', lnc // hr_shape[1].size),
        mtf.Dimension('sy_block', lnc // hr_shape[2].size),
        mtf.Dimension('sz_block', lnc // hr_shape[3].size)
    ])
    blocked = mtf.slicewise(_add_block_axes, [field],
                            output_dtype=tf.float32,
                            output_shape=blocked_shape,
                            name='my_reshape',
                            splittable_dims=splittables)

    for size_dim in hr_shape[-3:]:
        blocked = mtf.pad(blocked, [halo_size, halo_size], size_dim.name)
    # Halo exchange
    for blocks_dim, padded_dim in zip(hr_shape[1:4], blocked.shape[-3:]):
        blocked = mesh_ops.halo_reduce(blocked, blocks_dim, padded_dim,
                                       halo_size)

    return mesh_utils.cic_readout(blocked, state[0], halo_size)
def cic_paint_fr(field, state, output_shape, hr_shape, halo_size, splittables, mesh, weights=None):
    """Paint the positions in `state[0]` onto `field`.

    Ops performed:
    - reshape to the blocked layout derived from `hr_shape`
    - pad
    - paint
    - halo_reduce
    - slice to remove pad
    - reshape to `output_shape`

    Args:
      field: mtf Tensor; the size of its last dimension sets the block sizes.
      state: sequence whose first element holds the particle positions.
      output_shape: shape of the returned tensor.
      hr_shape: blocked shape; entries 1..3 are the block-count dimensions
        and the last three are the block-size dimensions.
      halo_size: padding added on both sides of each block-size dimension.
      splittables: forwarded to `mtf.slicewise` as `splittable_dims`.
      mesh: not used in this function.
      weights: forwarded to `mesh_utils.cic_paint`.

    Returns:
      The painted field with shape `output_shape`.
    """
    lnc = field.shape[-1].size

    def _add_block_axes(t):
        # Insert three singleton axes right after the leading axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    def _drop_block_axes(t):
        # Keep index 0 of the three axes following the leading one.
        return t[:, 0, 0, 0]

    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', lnc // hr_shape[1].size),
        mtf.Dimension('sy_block', lnc // hr_shape[2].size),
        mtf.Dimension('sz_block', lnc // hr_shape[3].size)
    ])
    painted = mtf.slicewise(_add_block_axes, [field],
                            output_dtype=tf.float32,
                            output_shape=blocked_shape,
                            name='my_reshape',
                            splittable_dims=splittables)

    for size_dim in hr_shape[-3:]:
        painted = mtf.pad(painted, [halo_size, halo_size], size_dim.name)

    painted = mesh_utils.cic_paint(painted, state[0], halo_size, weights)
    # Halo exchange
    for blocks_dim, padded_dim in zip(hr_shape[1:4], painted.shape[-3:]):
        painted = mesh_ops.halo_reduce(painted, blocks_dim, padded_dim,
                                       halo_size)
    # Remove borders
    for size_dim in hr_shape[-3:]:
        painted = mtf.slice(painted, halo_size, size_dim.size, size_dim.name)

    return mtf.slicewise(_drop_block_axes, [painted],
                         output_dtype=painted.dtype,
                         output_shape=output_shape,
                         name='my_dumb_reshape',
                         splittable_dims=splittables)
Exemple #8
0
def force_single(state,
                 lr_shape,
                 hr_shape,
                 kvec_lr,
                 halo_size,
                 cosmology=Planck15,
                 pm_nc_factor=1,
                 **kwargs):
    """Estimate the force on the particles given a state.

    The particles are painted onto the blocked mesh, the painted field is
    flattened to `lr_shape`, Fourier transformed and passed through the
    gradient-laplace kernel. Each resulting component is then read out at
    the particle positions.

    Args:
      state: sequence (X, P, F). Only X is read; P is returned unchanged and
        the incoming F is replaced.
      lr_shape: dimensions of the un-blocked field.
      hr_shape: dimensions of the blocked field (batch, 3 block-count dims,
        3 block-size dims).
      kvec_lr: sequence of three 1-D tensors of wave numbers.
      halo_size: padding added on both sides of each block-size dimension.
      cosmology: object exposing `Om0`.
      pm_nc_factor: must be 1.
      **kwargs: ignored.

    Returns:
      A tuple (X, P, F) where F is the newly computed force.
    """
    X, P, _ = state
    # TODO: support different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape

    def _drop_block_axes(t):
        # Keep index 0 of the three axes following the leading one.
        return t[:, 0, 0, 0]

    def _add_block_axes(t):
        # Insert three singleton axes right after the leading axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    # Paint the particles on the high resolution mesh
    density = mtf.zeros(X.mesh, shape=hr_shape)
    for size_dim in hr_shape[-3:]:
        density = mtf.pad(density, [halo_size, halo_size], size_dim.name)
    density = mesh_utils.cic_paint(density, X, halo_size)
    for blocks_dim, padded_dim in zip(hr_shape[1:4], density.shape[-3:]):
        density = mesh_ops.halo_reduce(density, blocks_dim, padded_dim,
                                       halo_size)
    # Remove borders
    for size_dim in hr_shape[-3:]:
        density = mtf.slice(density, halo_size, size_dim.size, size_dim.name)

    # Hack using a custom reshape to get back to the un-blocked layout
    lr_field = mtf.slicewise(_drop_block_axes, [density],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    raw_k_dims = [d.shape[0] for d in kvec_lr]
    k_dims_lr = [raw_k_dims[2], raw_k_dims[0], raw_k_dims[1]]
    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

    gradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

    # Reorder the low res FFTs which were transposed: y, z, x
    gradients = [gradients[2], gradients[0], gradients[1]]

    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', lnc // hr_shape[1].size),
        mtf.Dimension('sy_block', lnc // hr_shape[2].size),
        mtf.Dimension('sz_block', lnc // hr_shape[3].size)
    ])
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    displacement = []
    for component in gradients:
        real = mesh_utils.c2r3d(component, lr_shape[-3:])
        real = mtf.slicewise(_add_block_axes, [real],
                             output_dtype=tf.float32,
                             output_shape=blocked_shape,
                             name='my_reshape',
                             splittable_dims=splittables)

        for size_dim in hr_shape[-3:]:
            real = mtf.pad(real, [halo_size, halo_size], size_dim.name)
        for blocks_dim, padded_dim in zip(hr_shape[1:4], real.shape[-3:]):
            real = mesh_ops.halo_reduce(real, blocks_dim, padded_dim,
                                        halo_size)
        displacement.append(mesh_utils.cic_readout(real, X, halo_size))

    # Readout the force to particle positions
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
Exemple #9
0
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
    """Estimate the force on the particles given a state.

    The particles are painted onto the blocked, halo-padded mesh. The painted
    field is used both as is (high resolution) and downsampled (low
    resolution). Both are Fourier transformed and passed through the
    long-range and gradient-laplace kernels. For each component, the
    upsampled low-resolution part and the high-resolution part with its own
    downsampled-then-upsampled content removed are summed and read out at
    the particle positions.

    Args:
      state: sequence (X, P, F). Only X is read; P is returned unchanged and
        the incoming F is replaced.
      lr_shape: dimensions of the un-blocked low-resolution field.
      hr_shape: dimensions of the blocked field (batch, 3 block-count dims,
        3 block-size dims).
      kvec_lr: sequence of three 1-D wave-number tensors for the
        low-resolution field.
      kvec_hr: sequence of three 1-D wave-number tensors for the
        high-resolution field.
      halo_size: padding added on both sides of each block-size dimension.
      cosmology: object exposing `Om0`.
      downsampling_factor: forwarded to `mesh_utils.downsample` /
        `mesh_utils.upsample`; sizes are divided by `2**downsampling_factor`.
      pm_nc_factor: must be 1.
      antialias: forwarded to the downsampling of the high-resolution
        component inside the per-component loop.
      **kwargs: ignored.

    Returns:
      A tuple (X, P, F) where F is the newly computed force.
    """
    X, P, F = state
    #TODO: support different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    # Reorder the FFT dimensions, which were transposed: y, z, x
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)

    # Split the field into low and high resolution. The trailing 'h_dim'
    # of size 1 is added only for the downsample call and removed after.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    # NOTE(review): antialias is hard-coded to True here, while the
    # per-component downsampling below honours the `antialias` argument.
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])
    hr_field = mtf.reshape(high, high.shape[:-1])
    # Remove the (downsampled) halo borders from the low-resolution field
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)

    # Hack using a custom reshape to get back to the un-blocked layout
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield,
                                                    kvec_lr,
                                                    r_split=0)
    # Chain on the long-range result, as the high-resolution branch does.
    # Previously `lr_kfield` was passed here, discarding the line above.
    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(kfield_lr, kvec_lr)
    kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield,
                                                    kvec_hr,
                                                    r_split=0)
    kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

    # Reorder the FFT outputs, which were transposed: y, z, x
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
    kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

    displacement = []
    for f, g in zip(kfield_lr, kfield_hr):
        # Low-resolution component: back to real space, blocked layout,
        # halo padding at the downsampled halo size, then upsample.
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        # High-resolution component, on the same block sizes as f
        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Readout the force to particle positions
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
Exemple #10
0
def nbody_prototype(mesh,
                    infield=False,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Prototype of a function computing the LPT / N-body displacement.

    Builds the named dimensions and Fourier kernels, creates the initial
    field, evolves it (N-body when `FLAGS.nbody` is set, a single LPT step
    to the last stage otherwise) and paints the final particle positions.

    Args:
      mesh: mtf Mesh on which all tensors are created.
      infield: if true, the initial field is imported from the returned
        placeholder; otherwise it is generated by `mtfpm.linear_field`.
      nc: number of cells per side.
      bs: box size, forwarded to `mtfpm.linear_field`.
      batch_size: size of the batch dimension.
      a0: first entry of the `stages` schedule.
      a: last entry of the `stages` schedule.
      nsteps: number of entries in the `stages` schedule.
      dtype: dtype of the placeholder and of the final field.

    Returns:
      A tuple (initc, final_field, input_field): the initial field and the
      painted final field as mtf tensors, and the TF placeholder.
    """
    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the tabulated power spectrum once; the first two columns are
    # used as k and P(k).
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    # NOTE(review): ipklin is not used below in this function.
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin, shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    ## Compute the distributed initial conditions
    # The placeholder is always created and returned, even when unused.
    input_field = tf.placeholder(dtype, [batch_size, nc, nc, nc])
    if infield:
        initc = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    else:
        initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    if FLAGS.nbody:
        # LPT initial state at a0, then evolve through the stages
        state = mtfpm.lpt_init_single(
            initc,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only, evaluated directly at the last stage
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Custom reshape from the blocked layout back to (batch, nx, ny, nz)
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field, input_field
Exemple #11
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.
        Args:
          features: dictionary where keys are strings like "inputs" and "targets"
            and the values are the actual values of "inputs". See TPUEstimator's
            docs for more information
          labels: ignored argument
          mode: a tf.estimator.ModeKeys
          params: dictionary containing the key "context"
          config: ignored argument
        Returns:
          a TPUEstimatorSpec
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu and "context" in params:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memeory_usage)
            # deprecated mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
                mesh_shape.to_integer_list, physical_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            # deprecated mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        mtf_features = {}
        for key, x in features.items():
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            # Some auxiliary features may have been generated in packing.
            # The names of these new features are of the form
            #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
            #   We look up the lengths based on the original feature name, without
            #   the "_<suffix>".
            feature_length = sequence_length[key.split("_")[0]]
            length_dim = mtf.Dimension("length", feature_length)
            ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                             if ensemble_inputs else [])
            feature_shape = mtf.Shape(ensemble_dims +
                                      [outer_batch_dim, batch_dim, length_dim])
            x = tf.cast(features[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            if not use_tpu:
                tf.logging.info("feature %s : %s" % (key, x))
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=10)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)
            if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
                anon_targets = mtf.anonymize(mtf_features[key])

        if mode == tf.estimator.ModeKeys.PREDICT:

            def _feature_shape(key):
                feature_length = sequence_length[key.split("_")[0]]
                return mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", feature_length)
                ])

            mtf_features = {
                k: mtf.reshape(v, _feature_shape(k))
                for k, v in six.iteritems(mtf_features)
            }
            inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if has_partial_sequences:
                controlcodes = mtf_features["controlcode"]
            else:
                controlcodes = None

            if predict_fn:
                mtf_samples = predict_fn(model=transformer_model,
                                         features=mtf_features,
                                         variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Unitransformer):
                # pad so that there is enough room for the targets
                inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                                 length_dim.name)
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=get_variable_dtype(),
                    remove_partial_sequences=True)
            elif isinstance(transformer_model, Bitransformer_ll):
                mtf_samples = transformer_model.decode(
                    inputs,
                    attributes=attributes,
                    controlcodes=controlcodes,
                    has_partial_sequences=has_partial_sequences,
                    remove_partial_sequences=remove_partial_sequences,
                    variable_dtype=get_variable_dtype())  #
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            inputs = mtf.anonymize(inputs)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            inputs = lowering.export_to_tf_tensor(inputs)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"inputs": inputs, "outputs": outputs}

            # When exporting a model, we need to communicate to TF-Serving that
            # master variables need to be copied to their slave slice variables.
            # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
            # augment the default "local_init_op" here.
            #
            # The "ready_op" is also constructed here to ensure the variables
            # initialized by "local_init_op" are the same ones checked by "ready_op".
            #
            # WARNING: Any variables created outside of this model_fn()
            # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
            # checked by these ops.
            def scaffold_fn():
                return tf.train.Scaffold(
                    local_init_op=tf.group(
                        tf.train.Scaffold.default_local_init_op(),
                        lowering.copy_masters_to_slices(),
                        name="mtf_local_init_op"),
                    ready_op=tf.concat([
                        tf.report_uninitialized_variables(),
                        resources.report_uninitialized_resources()
                    ],
                                       axis=0,
                                       name="mtf_ready_op"))

            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        assert (mode == tf.estimator.ModeKeys.TRAIN
                or mode == tf.estimator.ModeKeys.EVAL)

        def logits_and_loss(mtf_features):
            """Compute logits and loss for one set of features.

            The model and its switches (`transformer_model`, `model_type`,
            `attribute_embedding`, `control_codes`, `cycle_consistency_loss`,
            `has_partial_sequences`, `remove_partial_sequences`, `lambda_ae`,
            `lambda_cycle`, `mode`) are read from the enclosing scope.

            Args:
              mtf_features: a dictionary. "targets" is always read; "inputs",
                "attribute", "codeprefixedtargets", "controlcode" and the
                optional "*_segmentation" / "*_subsegmentation" / "*_position"
                entries are read depending on the configuration.
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor. With a Bitransformer_ll model and
                `cycle_consistency_loss` set, this is
                lambda_ae * l_ae + lambda_cycle * l_cycle and the logits are
                those of the second (cycle) pass.
            Raises:
              ValueError: if `transformer_model` is neither a Unitransformer
                nor a Bitransformer and `model_type` is not
                "bi_student_teacher".
            """
            # Select the model inputs: for a language model they are the
            # targets shifted by one position along the length dimension.
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                # The unpacking requires "targets" to have exactly three
                # dimensions; the last one is used as the length dimension.
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            # Optional conditioning features; None when the switch is off.
            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            # Segmentation / position keyword arguments depend on the model
            # class; absent features are passed as None.
            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    # Decoder-side features come from the code-prefixed
                    # targets instead of the plain targets.
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    # First pass: loss of the model on the original inputs.
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    # Decode samples from the inputs under the 'training'
                    # gin config scope.
                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    # Second pass: the decoded samples are fed back as inputs
                    # and scored against the same targets.
                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    # Weighted sum of the two losses; the first-pass logits
                    # (logits_ae) are not returned.
                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                # NOTE(review): `num_microbatches` is read from the enclosing
                # scope; in the visible code it is only assigned in the TRAIN
                # branch -- confirm it is bound when this path runs in EVAL.
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_microbatches = serialize_num_microbatches(
                batch_dim, sequence_length, mesh_shape, layout_rules)
            if num_microbatches > 1:

                def serialized_fn(mtf_features):
                    """Return the loss of one microbatch divided by the microbatch count."""
                    microbatch_loss = logits_and_loss(mtf_features)[1]
                    scaled_loss = microbatch_loss / num_microbatches
                    return {"loss": scaled_loss}

                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, serialized_fn, batch_dim, num_microbatches)
                loss = loss_dict["loss"]
            else:
                loss = logits_and_loss(mtf_features)[1]
                var_grads = mtf.gradients(
                    [loss], [v.outputs[0] for v in graph.trainable_variables])

            if tpu_summaries:
                mtf.scalar_summary("loss", loss)

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            if isinstance(variable_filter, str):
                pattern = re.compile(variable_filter)
                variable_filter_fn = lambda v: pattern.search(v.name)
            elif variable_filter is None:
                variable_filter_fn = lambda v: True
            elif callable(variable_filter):
                variable_filter_fn = variable_filter
            else:
                raise ValueError(
                    "variable_filter must be None, a string, or a callable function"
                )
            trainable_vars = [
                v for v in graph.trainable_variables if variable_filter_fn(v)
            ]
            trainable_var_grads = [
                g for g, v in zip(var_grads, graph.trainable_variables)
                if variable_filter_fn(v)
            ]
            if len(trainable_vars) != len(graph.trainable_variables):
                tf.logging.info("Variables being trained:")
                tf.logging.info([v.name for v in trainable_vars])
                tf.logging.info("Variables not being trained:")
                tf.logging.info([
                    v.name for v in graph.trainable_variables
                    if not variable_filter_fn(v)
                ])

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                trainable_var_grads, trainable_vars)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.cast(tf_loss, tf.float32)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            if tpu_summaries:
                # has to be outside of
                # with mtf.utils.outside_all_rewrites()
                host_call = mtf.utils.create_host_call(model_dir)
                mtf.utils.remove_summaries()
            else:
                host_call = None

            with mtf.utils.outside_all_rewrites():

                if init_checkpoint:
                    ckpt_vars = {
                        v
                        for v, _ in tf.train.list_variables(init_checkpoint)
                    }
                    global_vars = {v.op.name for v in tf.global_variables()}
                    restore_vars = ckpt_vars.intersection(global_vars)
                    tf.logging.info("Initializing variables from %s:",
                                    init_checkpoint)
                    tf.logging.debug("\n".join(sorted(restore_vars)))
                    tf.logging.info("Variables in %s but not in graph:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                    tf.logging.info("Variables in graph but not in %s:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  {v: v
                                                   for v in restore_vars})

                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir,
                    summarize_config=True,
                    include_step_in_filename=False)

                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            logits, loss = logits_and_loss(mtf_features)
            anon_logits = mtf.anonymize(logits)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
            tf_loss = tf.cast(tf_loss, tf.float32)
            tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                                tf.float32)

            def simple_metrics(logits, labels):
                """Teacher-forced eval metrics; positions with label 0 get zero weight.

                Returns a dict with "neg_log_perplexity", "token_accuracy" and
                "sequence_accuracy", each built with tf.metrics.mean.
                """
                # Per-token weights: 1.0 where the label is non-zero.
                is_nonzero_label = tf.not_equal(labels, 0)
                token_weights = tf.cast(is_nonzero_label, tf.float32)

                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)

                # Arg-max predictions compared against the labels.
                best_ids = tf.argmax(logits, axis=-1)
                predicted = tf.cast(best_ids, labels.dtype)
                hits = tf.equal(predicted, labels)
                weighted_hits = tf.cast(hits, tf.float32) * token_weights

                # A sequence counts as correct when its number of weighted hits
                # equals its number of weighted tokens.
                hits_per_sequence = tf.reduce_sum(weighted_hits, -1)
                tokens_per_sequence = tf.reduce_sum(token_weights, -1)
                all_tokens_correct = tf.to_float(
                    tf.equal(hits_per_sequence, tokens_per_sequence))

                # Sequences whose weights sum to zero get zero weight.
                has_tokens = tf.to_float(
                    tf.not_equal(tf.reduce_sum(token_weights, -1), 0))

                metrics = {}
                metrics["neg_log_perplexity"] = tf.metrics.mean(
                    -cross_entropy, token_weights)
                metrics["token_accuracy"] = tf.metrics.mean(
                    weighted_hits, token_weights)
                metrics["sequence_accuracy"] = tf.metrics.mean(
                    all_tokens_correct, has_tokens)
                return metrics

            labels = lowering.export_to_tf_tensor(anon_targets)
            eval_metrics = (simple_metrics, [tf_logits, labels])
            with mtf.utils.outside_all_rewrites():
                restore_hook = mtf.MtfRestoreHook(lowering)
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Exemple #12
0
def nbody_fn(mesh,
             klin,
             plin,
             nc=FLAGS.nc,
             bs=FLAGS.box_size,
             batch_size=FLAGS.batch_size,
             a0=FLAGS.a0,
             a=FLAGS.af,
             nsteps=FLAGS.nsteps,
             dtype=tf.float32):
  """Pyramid N-body function.

  Builds the Mesh-TensorFlow graph that draws a linear field, splits it into a
  blocked high-resolution field and a downsampled low-resolution field,
  initialises the state with `mtfpm.lpt_init`, evolves it with `mtfpm.nbody`
  and paints the resulting state back onto a grid.

  The block counts, halo size and downsampling factor are read from
  `FLAGS.nx`, `FLAGS.ny`, `FLAGS.hsize` and `FLAGS.dsample`.

  Args:
    mesh: mesh on which the Mesh-TensorFlow tensors are created.
    klin: unused in this function; kept for interface compatibility.
    plin: array imported as a float32 tensor along an "npk" dimension of size
      `len(plin)` and handed to `mtfpm.linear_field`.
    nc: number of cells per side of the full-resolution grid.
    bs: forwarded to `mtfpm.linear_field` (defaults to `FLAGS.box_size`).
    batch_size: size of the "batch" dimension.
    a0: first value of the integration stages; also passed to
      `mtfpm.lpt_init` so the initial state matches the first stage.
    a: last value of the integration stages.
    nsteps: number of integration stages.
    dtype: unused in this function.

  Returns:
    A tuple `(initc, final_field)`: the field returned by
    `mtfpm.linear_field` and the painted final field, both with shape
    [batch, nx, ny, nz].
  """
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  # Parameters of the large scales decomposition
  downsampling_factor = FLAGS.dsample
  lnc = nc // 2**downsampling_factor

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions of the low resolution grid
  x_dim = mtf.Dimension("nx_lr", lnc)
  y_dim = mtf.Dimension("ny_lr", lnc)
  z_dim = mtf.Dimension("nz_lr", lnc)

  tx_dim = mtf.Dimension("tx_lr", lnc)
  ty_dim = mtf.Dimension("ty_lr", lnc)
  tz_dim = mtf.Dimension("tz_lr", lnc)

  # Number of blocks along each axis, and size of each block
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(
      mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
  ky = mtf.import_tf_tensor(
      mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
  kz = mtf.import_tf_tensor(
      mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
  # The kernels are handed over in (y, z, x) order.
  kv = [ky, kz, kx]

  # kvec for low resolution grid, divided by the downsampling ratio
  kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)

  kx_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  # kvec for high resolution blocks (block size plus a halo on each side)
  padded_sx_dim = mtf.Dimension('padded_sx_block',
                                nc // n_block_x + 2 * halo_size)
  padded_sy_dim = mtf.Dimension('padded_sy_block',
                                nc // n_block_y + 2 * halo_size)
  padded_sz_dim = mtf.Dimension('padded_sz_block',
                                nc // n_block_z + 2 * halo_size)
  kvec_hr = flowpm.kernels.fftk([
      nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
      nc // n_block_z + 2 * halo_size
  ],
                                symmetric=False)

  kx_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
  ky_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
  kz_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
  kv_hr = [ky_hr, kz_hr, kx_hr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, x_dim, y_dim, z_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  # Compute distributed initial conditions
  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

  # Reshaping array into high resolution mesh: each slice gets three
  # singleton block axes inserted after the batch axis.
  field = mtf.slicewise(
      lambda x: tf.expand_dims(
          tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [initc],
      output_dtype=tf.float32,
      output_shape=hr_shape,
      name='my_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

  # Add a halo of zeros on both sides of every block axis
  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

  # Halo exchange between neighbouring blocks
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

  # A trailing singleton dimension is added for the downsampling call and
  # removed again right after.
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  high = field
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

  low = mtf.reshape(low, low.shape[:-1])
  high = mtf.reshape(high, high.shape[:-1])

  # Strip the (downsampled) halo from the low resolution field
  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack using a custom slicewise reshape in place of mtf.reshape
  low = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  # The initial state is generated at a0, the first entry of `stages`,
  # rather than at a hard-coded 0.1.
  state = mtfpm.lpt_init(
      low,
      high,
      a0,
      kv_lr,
      kv_hr,
      halo_size,
      hr_shape,
      lr_shape,
      part_shape[1:],
      downsampling_factor=downsampling_factor,
      antialias=True,
  )

  final_state = mtfpm.nbody(
      state,
      stages,
      lr_shape,
      hr_shape,
      kv_lr,
      kv_hr,
      halo_size,
      downsampling_factor=downsampling_factor)

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Same slicewise reshape hack, back to the [batch, nx, ny, nz] layout
  final_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [final_field],
      output_dtype=tf.float32,
      output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
      name='my_dumb_reshape',
      splittable_dims=part_shape[:-1] + hr_shape[:4])

  return initc, final_field
def benchmark_model(mesh):
  """Build the benchmark graph: linear field, single-mesh LPT init, CIC paint.

  A linear field is drawn with `mtfpm.linear_field`, an LPT state is built
  with `mtfpm.lpt_init_single`, and that state is painted onto a blocked grid
  with halo exchange. No N-body stepping is performed.

  Simulation parameters are read from `FLAGS` (`box_size`, `cube_size`,
  `batch_size`, `a0`, `pm_steps`); the block decomposition (8 x 4 x 1 blocks,
  halo size 4) is fixed in the function body.

  Args:
    mesh: mesh on which the Mesh-TensorFlow tensors are created.

  Returns:
    A mtf.Tensor holding the sum of the painted field.
  """
  # Setup parameters
  bs = FLAGS.box_size
  nc = FLAGS.cube_size
  batch_size = FLAGS.batch_size
  a0 = FLAGS.a0
  a = 1.0
  nsteps = FLAGS.pm_steps

  # Parse the table once; the first two columns are used as klin and plin.
  pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
  klin = pk_table[0]
  plin = pk_table[1]
  ipklin = iuspline(klin, plin)

  # Integration steps (currently unused: the N-body call is disabled below).
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # NOTE(review): `initial_conditions`, `lap` and `grad_*` are not used by the
  # rest of this function. They are kept because they may add ops to the
  # TensorFlow graph -- confirm before removing.
  # Generate a batch of 3D initial conditions
  initial_conditions = flowpm.linear_field(
      nc,  # size of the cube
      bs,  # Physical size of the cube
      ipklin,  # Initial power spectrum
      batch_size=batch_size)

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  from flowpm.kernels import laplace_kernel, gradient_kernel
  lap = tf.cast(laplace_kernel(kvec), tf.complex64)
  grad_x = gradient_kernel(kvec, 0)
  grad_y = gradient_kernel(kvec, 1)
  grad_z = gradient_kernel(kvec, 2)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = 8
  n_block_y = 4
  n_block_z = 1
  halo_size = 4

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions of the "low resolution" grid; here it has the full size nc
  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  # Number of blocks along each axis, and size of each block
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Import the Fourier kernels computed above (kvec) as mesh tensors
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  # The kernels are handed over in (y, z, x) order.
  kv = [ky, kz, kx]

  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  # Tensor shapes used below
  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)
  state = mtfpm.lpt_init_single(
      initc,
      a0,
      kv_lr,
      halo_size,
      lr_shape,
      hr_shape,
      part_shape[1:],
      antialias=True,
  )

  # The N-body evolution (mtfpm.nbody) is not run here; the LPT state is
  # painted directly.
  final_state = state

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Hack using a custom slicewise reshape in place of mtf.reshape
  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                              output_dtype=tf.float32,
                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                              name='my_dumb_reshape',
                              splittable_dims=part_shape[:-1] + hr_shape[:4])

  return mtf.reduce_sum(final_field)
Exemple #14
0
def lpt_prototype(mesh,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """
    Prototype of function computing LPT displacement.

    A linear field is drawn with `mtfpm.linear_field`, an LPT state is built
    with `mtfpm.lpt_init_single`, and that state is painted onto a blocked
    grid with halo exchange. No N-body stepping is performed.

    The block counts and requested halo size are read from `FLAGS.nx`,
    `FLAGS.ny` and `FLAGS.hsize`. The halo size is reduced (with a printed
    warning) when it is at least half of the smallest block size.

    Args:
      mesh: mesh on which the Mesh-TensorFlow tensors are created.
      nc: number of cells per side of the grid.
      bs: forwarded to `mtfpm.linear_field` (defaults to `FLAGS.box_size`).
      batch_size: size of the "batch" dimension.
      a0: passed to `mtfpm.lpt_init_single`.
      a: unused while the N-body evolution is disabled.
      nsteps: unused while the N-body evolution is disabled.

    Returns:
      A tuple `(initc, final_field)` of mesh tensorflow tensors: the field
      returned by `mtfpm.linear_field` and the painted field.
    """
    # Only the second column of the table is needed here.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Cap the halo at half of the smallest block size.
    min_block_size = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block_size:
        new_size = int(0.5 * min_block_size)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the "low resolution" grid; here it has the full size nc
    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    # Number of blocks along each axis, and size of each block
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # The kernels are handed over in (y, z, x) order.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    state = mtfpm.lpt_init_single(
        initc,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # The N-body evolution (mtfpm.nbody) is not run here; the LPT state is
    # painted directly.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack using a custom slicewise reshape in place of mtf.reshape
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT (gradient-based subword tokenization) from Charformer.

    Candidate blocks of width 2, 3 and 4 (plus 1 when
    `consider_chars_as_blocks` is set) are mean-pooled along `length_dim`,
    scored with a bias-free dense layer, expanded back to the original length
    and then combined position-wise using the block scores.

    Args:
      x: a Tensor containing length_dim. Its last dimension is used as the
        model dimension.
      length_dim: a Dimension of x along which blocks are formed.
      max_subword_length: integer. Currently ignored.
      downsample: integer pool size applied to the output along length_dim,
        or a falsy value to skip downsampling.
      use_offsets: boolean. If true, every block width is additionally
        evaluated on x shifted by 1 .. width-1 positions.
      consider_chars_as_blocks: boolean. If true, blocks of width 1 are added
        to the candidates.
      use_block_pos_embedding: boolean. If true, a sinusoidal position
        embedding is added before pooling (requires `context`).
      share_block_kernel: boolean. If true, a single scoring kernel is shared
        by all block widths (requires `context`).
      memory_embeddings: integer. Currently ignored.
      context: Context. Only read when `use_block_pos_embedding` or
        `share_block_kernel` is set.
      block_mixing_mode: Str for block mixing. "score_attention" mixes the
        block scores with an attention step; any other value skips mixing.
      activation: Str for block ranking. "softmax" and "tanh" normalize the
        scores accordingly, any other value leaves them untouched. "sigtanh"
        allocates two score channels per block.
      downsample_function: Str. Only "mean" is implemented.

    Returns:
      a Tensor with the same shape as x when `downsample` is falsy. Otherwise
      length_dim is first padded to a multiple of `downsample` and then pooled
      with `mtf.pool_tensor_1d` using pool size `downsample`.

    Raises:
      ValueError: if `downsample` is set and `downsample_function` is not
        "mean".
    """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF: append a size-1 "tmp" dim, tile it n times
        # and fold it back into repeat_dim.
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # two score channels per block
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    # Extra size-1 dimension along which all candidate blocks are stacked.
    # It is identical for every candidate, so it is created once.
    new_tmp_dim = mtf.Dimension("extra_dim", 1)

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # this is turned off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            # int() keeps the tile count an integer even where math.ceil
            # returns a float.
            block_pos_emb = _repeat(
                block_pos_emb,
                int(math.ceil(length_dim.size / float(subword_len))),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    # NOTE(review): block_pos_emb is reassigned here, so the
                    # shifts accumulate over successive offsets, whereas xoff
                    # is always shifted from the original x. Confirm this is
                    # intended.
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            # Look the pooled dimension up by the caller's dimension name
            # instead of assuming it is literally called "length".
            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # Pad up to the original length. The amount must be positive,
                # i.e. target size minus current size.
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    # After the concat, "extra_dim" has one entry per candidate block.
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    # Score-weighted sum over the candidate blocks.
    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name(length_dim.name)
        if output_length.size % int(downsample) != 0:
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
Exemple #16
0
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Prototype computing an LPT displacement and painting the density field.

    The initial conditions are imported on `mesh`, displaced with
    `mtfpm.lpt_init_single` evaluated at scale factor `a`, painted with CIC
    onto a block-decomposed mesh with halo exchange, and reshaped back to a
    single cube.

    Args:
      mesh: the Mesh TensorFlow mesh on which all tensors are created.
      initial_conditions: TensorFlow tensor imported with shape
        [batch, nc, nc, nc].
      derivs: unused; accepted for interface compatibility only.
      nc: number of cells per side of the cube.
      bs: box size. Not used by this function.
      batch_size: size of the batch dimension.
      a0: first scale factor of `stages` (only relevant to the disabled
        N-body call).
      a: scale factor at which the LPT displacement is evaluated.
      nsteps: number of entries in `stages`.

    Returns:
      A Mesh TensorFlow tensor holding the painted field, with dimensions
      [batch, fnx, fny, fnz], each spatial dimension of size `nc`.
    """
    del derivs  # Unused.

    # Only consumed by the N-body call that is currently commented out below.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Only the second column of the table is needed, so the file is parsed
    # once.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    ffx_dim = mtf.Dimension("fnx", nc)
    ffy_dim = mtf.Dimension("fny", nc)
    ffz_dim = mtf.Dimension("fnz", nc)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    # NOTE(review): `pk` is not read afterwards; the import is kept so the set
    # of operations registered on the mesh graph is unchanged.
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr, ky_lr, kz_lr = [
        mtf.import_tf_tensor(mesh,
                             kvec.squeeze().astype('float32'),
                             shape=[k_dim])
        for kvec, k_dim in zip(kvec_lr, [tx_dim, ty_dim, tz_dim])
    ]
    # The (y, z, x) ordering is what the original code passes to
    # lpt_init_single.
    kv_lr = [ky_lr, kz_lr, kx_lr]

    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    field = mtf.import_tf_tensor(mesh, initial_conditions, shape=part_shape)

    state = mtfpm.lpt_init_single(
        field,
        a,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # Here we can run our nbody
    final_state = state
    # final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
    #                                  kv_lr, halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    #final_field = mtf.reshape(final_field,  [batch_dim, fx_dim, fy_dim, fz_dim])
    # Hack using a custom slicewise reshape instead of mtf.reshape to merge
    # the block dimensions back into the full cube.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    final_field = mtf.reshape(final_field,
                              [batch_dim, ffx_dim, ffy_dim, ffz_dim])
    return final_field
Exemple #17
0
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Builds the Mesh TensorFlow graph of the reconstruction model.

    A trainable linear field is evolved with LPT (or LPT followed by N-body
    when `FLAGS.nbody` is set), painted onto the mesh, smoothed in Fourier
    space and compared with `data` through a Poisson log-likelihood. A
    power-spectrum prior on the linear field is added to form the loss.

    Args:
      mesh: the Mesh TensorFlow mesh on which all tensors are created.
      data: observed field, convertible to a tensor of shape
        [batch, nc, nc, nc].
      R0: smoothing scale used in the Fourier-space weights
        exp(-k^2 (R0 * bs / nc)^2).
      x0: initial value of the linear field variable, or None to draw it from
        a unit normal distribution.
      nc: number of cells per side of the cube.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: initial scale factor.
      a: final scale factor.
      nsteps: number of scale-factor stages between `a0` and `a`.
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple (fields, metrics, kv) where fields is [fieldvar, sample],
      metrics is [chisq, prior, loss] and kv is the list of imported
      wave-vector tensors in (y, z, x) order.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later with an unbound
        # local variable.
        raise ValueError(
            "Unsupported dtype %s: expected tf.float32 or tf.float64" %
            (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is clamped to half of the smallest block.
    min_block_size = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block_size:
        new_size = int(0.5 * min_block_size)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is needed, so the file is parsed
    # once.
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the wave vectors are always imported as float32 and the
    # variable below uses the default dtype, so the float64 path looks
    # incomplete - confirm before relying on dtype=tf.float64.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Forward model: LPT at a0 followed by N-body, or LPT directly at the
    # final scale factor.
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Custom slicewise reshape merging the block dimensions back into the
    # full cube.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    # kv is stored as (y, z, x); reorder the dimensions to (x, y, z).
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # Mode amplitude divided by the square root of the power spectrum
        # interpolated at |k|.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    # `diff` is only used below for the shape of its last three dimensions.
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply every mode by exp(-k^2 (R0 * bs / nc)^2).
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Rate multiplier of the Poisson likelihood.
    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        # Negative Poisson log-probability of the data given the model field.
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)  #
    chisq = mtf.reduce_sum(chisq)
    ##    #

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        # Draw a Poisson sample from the smoothed model field; `data` is only
        # part of the signature and is not read.
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)  #
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Prototype computing a multi-resolution LPT displacement.

    The initial conditions are split into blocks, padded with halos and
    exchanged between neighbouring blocks. A downsampled (low resolution)
    copy is built alongside the high resolution blocks, both are passed to
    `mtfpm.lpt_init` at scale factor `a0`, and the resulting particles are
    painted with CIC onto the full cube.

    Args:
      mesh: the Mesh TensorFlow mesh on which all tensors are created.
      initial_conditions: TensorFlow tensor reshaped into
        [batch_size, nx blocks, ny blocks, nz blocks, block sizes...].
      derivs: unused; accepted for interface compatibility only.
      nc: number of cells per side of the cube.
      bs: box size. Not used by this function.
      batch_size: size of the batch dimension.
      a0: scale factor at which the LPT displacement is evaluated.
      a: last scale factor of `stages` (only relevant to the disabled N-body
        call).
      nsteps: number of entries in `stages`.

    Returns:
      A Mesh TensorFlow tensor holding the painted field, with dimensions
      [batch, nx, ny, nz], each spatial dimension of size `nc`.
    """
    del derivs  # Unused.

    # Only consumed by the N-body call that is currently commented out below.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Only the second column of the table is needed, so the file is parsed
    # once.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition
    downsampling_factor = FLAGS.dsample
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    # NOTE(review): `pk` is not read afterwards; the import is kept so the set
    # of operations registered on the mesh graph is unchanged.
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # kvec for low resolution grid, rescaled by the downsampling factor
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)
    kx_lr, ky_lr, kz_lr = [
        mtf.import_tf_tensor(mesh,
                             kvec.squeeze().astype('float32') /
                             2**downsampling_factor,
                             shape=[k_dim])
        for kvec, k_dim in zip(kvec_lr, k_dims)
    ]
    # The (y, z, x) ordering is what the original code passes to lpt_init.
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks (block size plus a halo on each side)
    padded_sizes = [
        nc // n_blocks + 2 * halo_size
        for n_blocks in (n_block_x, n_block_y, n_block_z)
    ]
    padded_sx_dim = mtf.Dimension('padded_sx_block', padded_sizes[0])
    padded_sy_dim = mtf.Dimension('padded_sy_block', padded_sizes[1])
    padded_sz_dim = mtf.Dimension('padded_sz_block', padded_sizes[2])
    kvec_hr = flowpm.kernels.fftk(padded_sizes, symmetric=False)

    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype('float32'),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype('float32'),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype('float32'),
                                 shape=[padded_sz_dim])
    kv_hr = [kx_hr, ky_hr, kz_hr]

    lr_shape = [batch_dim, x_dim, y_dim, z_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Split every spatial axis into (number of blocks, block size) and move
    # the block counts in front so the layout matches hr_shape. The batch and
    # z-block sizes follow batch_size / n_block_z instead of a hard-coded 1.
    initc = tf.reshape(initial_conditions, [
        batch_size, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y,
        n_block_z, nc // n_block_z
    ])
    initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6])
    field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape)

    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # A trailing size-1 dimension is added for the downsampling call and
    # removed again afterwards.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Remove the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using a custom slicewise reshape instead of mtf.reshape to merge
    # the block dimensions.
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=tf.float32,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Hack to handle reshape across multiple dimensions
    #low = mtf.reshape(low, [batch_dim, x_dim, low.shape[2], low.shape[5], z_dim])
    #low = mtf.reshape(low, lr_shape)

    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        k_dims,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    # Here we can run our nbody
    final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    #final_field = mtf.reshape(final_field,  [batch_dim, fx_dim, fy_dim, fz_dim])
    # Hack using a custom slicewise reshape instead of mtf.reshape to merge
    # the block dimensions back into the full cube.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return final_field
Exemple #19
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the Mesh TensorFlow graph for a single-resolution reconstruction.

    A trainable field variable named 'linear' is evolved with
    `mtfpm.lpt_init_single` (followed by `mtfpm.nbody_single` when
    `FLAGS.nbody` is set), painted on the blocked mesh, passed through the two
    small networks returned by `setupfnn()`, and compared with `data` to build
    a loss made of a smoothed log-residual term plus a power-spectrum prior.

    Args:
      mesh: the mtf mesh all tensors are created on.
      data: observed field; it is imported with shape [batch, nc, nc, nc] and
        its `dtype` / `shape` are used for the input placeholders.
      nc: number of cells per side of the grid.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first scale factor of the integration stages.
      a: last scale factor of the integration stages.
      nsteps: number of integration stages between `a0` and `a`.
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple (initc, model, loss, var_grads, update_op, linearop,
      input_field, lr, R0, M0, width, chisq, prior, off, istd) where
      `input_field`, `lr`, `R0`, `M0`, `width`, `off` and `istd` are
      tf placeholders and the remaining entries are mtf tensors/operations
      (`var_grads` is a one-element list).

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype stay unbound and the print below
        # fails with an unrelated-looking UnboundLocalError.
        raise ValueError("dtype must be tf.float32 or tf.float64, got %r" %
                         (dtype, ))
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    R1, R2 = 3., 3 * 1.2  # the two smoothing scales fed to cwise_fingauss
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block; clamp it if it does.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the tabulated power spectrum once; only the second column is
    # needed here.
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    plin = pk_table[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the kernels are cast to 'float32' regardless of `dtype`;
    # confirm this is intended when running with tf.float64.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # Stored in (y, z, x) order; k_dims below restores (x, y, z).
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    #
    # Begin simulation

    # The field being optimised, plus an op to (re)initialise it from a
    # placeholder.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # No N-body steps: evaluate LPT directly at the last stage.
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Custom slicewise "reshape" from the blocked layout to
    # [batch, nx, ny, nz].
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    ##
    x = final_field

    # Network parameters; the exact layout is defined by setupfnn().
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)

    # Frequency dimensions in (x, y, z) order, recovered from kv's
    # (y, z, x) ordering.
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Filter the painted field in Fourier space with cwise_decic, then
    # subtract 1.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two cwise_fingauss-filtered copies (scales R1 and R2) and their
    # difference.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        """Per-slice network producing sigmoid(width * output) in (0, 1)."""
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        # Standardise the stacked features before the dense layers.
        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        """Per-slice network on the three point-wise features, rescaled by msy/mmy."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # Same frequency dimensions as above (both are derived from kv).
    k_dims_pr = k_dims
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Return |kfield| divided by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # pk is treated as sampled on a grid regular in log(k) between
        # x_ref_min and x_ref_max.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    # The model is filtered with cwise_fingauss at scale R1 before the
    # comparison.
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off = tf.placeholder(tf.float32, shape=data.shape)
    istd = tf.placeholder(tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    # mtfistd is imported but the loss below divides by a fixed 0.25 instead
    # of using it; the placeholder is still part of the returned tuple.
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by a Gaussian low-pass of width R0 * bs / nc."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply kfield by one minus a Gaussian of width R0 * bs / nc + 1 / nyq."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss w.r.t. the field, passed through _cwise_highpass
    # before the update.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    # Plain gradient-descent step with a fed learning rate.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the Mesh TensorFlow graph for a multi-resolution reconstruction.

    A trainable field variable named 'linear', stored in the blocked layout
    `hr_shape`, is split into a high-resolution blocked field and a
    downsampled low-resolution field, evolved with `mtfpm.lpt_init` (followed
    by `mtfpm.nbody` when `FLAGS.nbody` is set), painted, and compared with
    `data` to build a loss made of a smoothed residual term plus a
    power-spectrum prior.

    Args:
      mesh: the mtf mesh all tensors are created on.
      data: observed field; it is imported with shape [batch, nc, nc, nc] and
        its `dtype` is used for the input placeholder.
      nc: number of cells per side of the full-resolution grid.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first scale factor of the integration stages.
      a: last scale factor of the integration stages.
      nsteps: number of integration stages between `a0` and `a`.
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple (initc, final_field, loss, var_grads, update_op, linearop,
      input_field, lr, R0) where `input_field`, `lr` and `R0` are
      tf placeholders and the remaining entries are mtf tensors/operations
      (`var_grads` is a one-element list).

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype stay unbound and the print below
        # fails with an unrelated-looking UnboundLocalError.
        raise ValueError("dtype must be tf.float32 or tf.float64, got %r" %
                         (dtype, ))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block; clamp it if it does.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the tabulated power spectrum once; only the second column is
    # needed here.
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    plin = pk_table[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    # Not referenced below; the imports are kept so the set of ops created on
    # the mesh is unchanged.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)

    # The low-resolution wave vectors are divided by 2**downsampling_factor.
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    # Stored in (y, z, x) order, like every kv list in this function.
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    # Each block is padded by halo_size on both sides.
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # The field being optimised (blocked layout), plus an op to
    # (re)initialise it from a placeholder.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)
    #
    field = fieldvar
    # Unblocked [batch, nx, ny, nz] view of the variable, returned to the
    # caller.
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Pad every block with a halo and exchange the margins with the
    # neighbours.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # mesh_utils.downsample is fed a tensor with a trailing size-1 dimension,
    # which is dropped again right after.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low-resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using a custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        # Start from a0, the first entry of `stages`, rather than a
        # hard-coded 0.1, so the initial state matches the integration
        # schedule.
        state = mtfpm.lpt_init(low,
                               high,
                               a0,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        # No N-body steps: evaluate LPT directly at the last stage.
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Custom slicewise "reshape" from the blocked layout to
    # [batch, nx, ny, nz].
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # Frequency dimensions in (x, y, z) order, recovered from kv_pr's
    # (y, z, x) ordering.
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Return |kfield| divided by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # pk is treated as sampled on a grid regular in log(k) between
        # x_ref_min and x_ref_max.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by a Gaussian low-pass of width R0 * bs / nc."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply kfield by one minus a Gaussian of width R0 * bs / nc + 1 / nyq."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss w.r.t. the field, passed through _cwise_highpass
    # before the update.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    # Plain gradient-descent step with a fed learning rate.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0