Ejemplo n.º 1
0
def Pooling(tsr,
            fltr,
            stride=(1, 1),
            padding='VALID',
            pooling_fn=tf.nn.max_pool,
            name=None):
    """Apply a 2-D pooling function slice-wise to a Mesh TensorFlow tensor.

    Args:
        tsr: mtf.Tensor to pool. Dimensions 1 and 2 are the pooled axes
            (presumably an NHWC layout -- confirm against callers).
        fltr: pair ``(window_h, window_w)`` giving the pooling window.
        stride: stride specification, normalized by NormalizeStrideAndPad
            before use.
        padding: padding specification, normalized by NormalizeStrideAndPad
            before use.
        pooling_fn: callable invoked per slice as
            ``pooling_fn(x, ksize, strides, padding)``.
        name: optional variable-scope name (defaults to 'pool').

    Returns:
        The pooled mtf.Tensor; dimensions 1 and 2 keep their names but are
        resized to the pooled extent.
    """
    stride, padding = NormalizeStrideAndPad(stride, padding)
    with tf.variable_scope(name, default_name='pool'):

        def pool_slice(x):
            # Runs on each device-local slice of `tsr`.
            return pooling_fn(x, [1, fltr[0], fltr[1], 1],
                              [1, stride[0], stride[1], 1], padding)

        def pooled_size(size, window, step):
            """Spatial extent produced by pooling an axis of length `size`."""
            if padding == 'VALID':
                return (size - window) // step + 1
            # 'SAME' pads the input so the output is ceil(size / step).
            # Plain floor division under-counts whenever `size` is not a
            # multiple of `step`, making the declared shape disagree with
            # the tensor actually returned by `pooling_fn`.
            return -(-size // step)

        # Output shape: only the two pooled dimensions change size.
        h_o = pooled_size(tsr.shape[1].size, fltr[0], stride[0])
        w_o = pooled_size(tsr.shape[2].size, fltr[1], stride[1])

        output_shape = tsr.shape.resize_dimension(tsr.shape[1].name, h_o)
        output_shape = output_shape.resize_dimension(tsr.shape[2].name, w_o)

        # Only the first and last dimensions may be split across the mesh;
        # the pooled dimensions must be whole inside every slice.
        splittable_dims = [tsr.shape.dims[0], tsr.shape.dims[-1]]
        return mtf.slicewise(pool_slice, [tsr],
                             output_shape,
                             tsr.dtype,
                             splittable_dims,
                             grad_function=GenericGradFn)
Ejemplo n.º 2
0
def apply_gradient_laplace_kernel(kfield, kvec, order=1):
  """
  Applies an inverse-Laplace weighting followed by a gradient kernel to a
  Fourier-space field, along all three spatial directions.

  Parameters:
  -----------
  kfield: mtf.Tensor
    Complex Fourier-space field; the slice function reshapes the wave
    vectors against its axes 1, 2 and 3.

  kvec: list of 3 mtf.Tensor
    1D wave-vector components, in the same order as axes 1-3 of kfield.

  order: int
    Only order=1 is implemented.

  Returns:
  --------
  Tuple (dkfield_dx, dkfield_dy, dkfield_dz) of tensors with the shape and
  dtype of kfield.
  """
  # TODO: support order=0 kernel
  assert order == 1

  def _swise_fn(kfield, kx, ky, kz):
    # Broadcast each 1D wave vector along its own axis of the field
    kx = tf.reshape(kx, [1, -1, 1, 1])
    ky = tf.reshape(ky, [1, 1, -1, 1])
    kz = tf.reshape(kz, [1, 1, 1, -1])
    kk = (kx**2 + ky**2 + kz**2)
    # 1/k^2 weights, with the k=0 mode set to zero
    wts = tf.where(kk > 0, 1. / kk, tf.zeros_like(kk))
    kfield = kfield * tf.cast(wts, kfield.dtype)
    # Gradient kernel i/6 * (8 sin(k) - sin(2k)) along each direction
    dkfield_dx = kfield * 1.j / 6.0 * tf.cast(8 * tf.sin(kx) - tf.sin(2 * kx),
                                              kfield.dtype)
    dkfield_dy = kfield * 1.j / 6.0 * tf.cast(8 * tf.sin(ky) - tf.sin(2 * ky),
                                              kfield.dtype)
    dkfield_dz = kfield * 1.j / 6.0 * tf.cast(8 * tf.sin(kz) - tf.sin(2 * kz),
                                              kfield.dtype)
    return dkfield_dx, dkfield_dy, dkfield_dz

  # The slice function returns tensors in kfield's dtype, so declare that
  # dtype rather than a hard-coded complex64 (which would mis-declare
  # complex128 inputs).
  dkfield_dx, dkfield_dy, dkfield_dz = mtf.slicewise(
      _swise_fn, [kfield] + kvec,
      output_shape=[kfield.shape] * 3,
      output_dtype=[kfield.dtype] * 3,
      splittable_dims=kfield.shape[:])
  return dkfield_dx, dkfield_dy, dkfield_dz
def cic_readout(mesh, part, halo_size, name=None):
  """
  Slicewise Cloud-In-Cell readout of `mesh` at the positions in `part`.

  Two slicewise passes are issued: `_cic_indexing` produces, per particle,
  an index tensor (extra dims nk=8, nl=4) and a value tensor (extra dim
  nk=8); `_cic_readout` then combines them with the mesh, shifting by
  `halo_size` along the last three axes.

  Returns a tensor shaped like `part` without its last dimension.
  `name` is accepted but not used.
  """
  neighbor_dim = mtf.Dimension("nk", 8)
  coord_dim = mtf.Dimension("nl", 4)
  particle_dims = part.shape.dims[:-1]
  # Dimensions along which both passes may be split across devices
  split_dims = mesh.shape[:-3] + part.shape[1:-1]

  def _read_slice(mesh_slice, index_slice, value_slice):
    return _cic_readout(
        mesh_slice, index_slice, value_slice,
        shift=[0, halo_size, halo_size, halo_size])

  indices, values = mtf.slicewise(
      _cic_indexing, [mesh, part],
      output_dtype=[tf.float32, tf.float32],
      output_shape=[
          mtf.Shape(particle_dims + [neighbor_dim, coord_dim]),
          mtf.Shape(particle_dims + [neighbor_dim])
      ],
      splittable_dims=split_dims)

  return mtf.slicewise(
      _read_slice,
      [mesh, indices, values],
      output_dtype=tf.float32,
      output_shape=part.shape[:-1],
      splittable_dims=split_dims)
Ejemplo n.º 4
0
def cic_paint(mesh, part, halo_size, weight=None, name=None):
    """
    Distributed Cloud In Cell painting.

    Parameters:
    -----------
    mesh: tensor (batch_size, ..., nc, nc, nc)
      Input mesh tensor; its last three dimensions are the spatial axes
      that get painted on.

    part: tensor (batch_size, ..., 3)
      Particle coordinates (last dimension), assumed to be in mesh units.

    halo_size: int
      Shift applied to the three spatial axes when painting, to account
      for the halo padding of `mesh`.

    weight: tensor, optional
      Per-particle weights. Its dimensions are expected to be those of
      `part` without the trailing coordinate dimension. When None every
      particle has unit weight.

    name: unused.

    Returns:
    --------
    Tensor with the shape of `mesh`, holding the painted field.

    ----
    Painting is applied slicewise; particles falling outside of the local
    slice rely on the halo region (and a subsequent halo exchange done by
    the caller) to be accounted for.
    """
    nk = mtf.Dimension("nk", 8)
    nl = mtf.Dimension("nl", 4)
    splittable = mesh.shape[:-3] + part.shape[1:-1]

    # Per particle: indices of the 8 neighbouring cells and the matching
    # CiC kernel values.
    indices, values = mtf.slicewise(
        _cic_indexing, [mesh, part],
        output_dtype=[tf.float32, tf.float32],
        output_shape=[
            mtf.Shape(part.shape.dims[:-1] + [nk, nl]),
            mtf.Shape(part.shape.dims[:-1] + [nk])
        ],
        splittable_dims=splittable)

    if weight is not None:
        # Previously `weight` was accepted but silently ignored. Scale the
        # kernel values of each particle by its weight; pinning the output
        # shape keeps the layout expected by _cic_paint.
        values = mtf.multiply(values, weight, output_shape=values.shape)

    mesh = mtf.slicewise(lambda x, y, z: _cic_paint(
        x, y, z, shift=[0, halo_size, halo_size, halo_size]),
                         [mesh, indices, values],
                         output_dtype=tf.float32,
                         output_shape=mesh.shape,
                         splittable_dims=splittable)
    return mesh
Ejemplo n.º 5
0
def lpt_init_single(lr_field, a0, kvec_lr, halo_size, lr_shape, hr_shape, part_shape, antialias=True, order=1, post_filtering=True, cosmology=Planck15):
  """Builds the first-order LPT state (positions, momenta, forces) at a0.

  The displacement is obtained by applying the gradient/inverse-Laplace
  kernel to the Fourier transform of `lr_field`, bringing each component
  back to real space, splitting it into halo-padded blocks following
  `hr_shape`, and reading it out at the grid positions of the particles.

  `antialias`, `order` and `post_filtering` are accepted but not used here.

  Returns the tuple (X, P, F).
  """
  batch_dim = lr_field.shape[0]
  cells_per_side = lr_shape[-1].size

  # Particles start on a regular grid, replicated over the batch dimension
  grid = mesh_ops.mtf_indices(lr_field.mesh, shape=part_shape, dtype=tf.float32)
  X = mtf.einsum([mtf.ones(lr_field.mesh, [batch_dim]), grid],
                 output_shape=[batch_dim] + grid.shape[:])

  # Fourier dimensions, reordered to match the transposed FFT layout
  fourier_dims = [k.shape[0] for k in kvec_lr]
  fourier_dims = [fourier_dims[2], fourier_dims[0], fourier_dims[1]]

  lr_kfield = mesh_utils.r2c3d(lr_field, fourier_dims)
  kgradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

  # The FFT outputs come out as (y, z, x): put them back in (x, y, z) order
  kgradients = [kgradients[2], kgradients[0], kgradients[1]]

  def _insert_block_axes(x):
    # Three singleton axes inserted at position 1, one per block dimension
    for _ in range(3):
      x = tf.expand_dims(x, axis=1)
    return x

  blocked_shape = mtf.Shape(hr_shape[0:4] + [
      mtf.Dimension('sx_block', cells_per_side // hr_shape[1].size),
      mtf.Dimension('sy_block', cells_per_side // hr_shape[2].size),
      mtf.Dimension('sz_block', cells_per_side // hr_shape[3].size)])
  split_dims = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

  components = []
  for kgrad in kgradients:
    comp = mesh_utils.c2r3d(kgrad, lr_shape[-3:])
    comp = mtf.slicewise(_insert_block_axes,
                         [comp],
                         output_dtype=tf.float32,
                         output_shape=blocked_shape,
                         name='my_reshape',
                         splittable_dims=split_dims)

    # Pad every block with a halo, then exchange halos between neighbours
    for block_size_dim in hr_shape[-3:]:
      comp = mtf.pad(comp, [halo_size, halo_size], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], comp.shape[-3:]):
      comp = mesh_ops.halo_reduce(comp, blocks_dim, block_size_dim, halo_size)

    # Read the displacement component out at the particle positions
    components.append(mesh_utils.cic_readout(comp, X, halo_size))

  displacement = mtf.stack(components, "ndim", axis=4)

  pt = PerturbationGrowth(cosmology, a=[a0], a_normalize=1.0)
  DX = pt.D1(a0) * displacement
  P = (a0 ** 2 * pt.f1(a0) * pt.E(a0)) * DX
  F = (a0 ** 2 * pt.E(a0) * pt.gf(a0) / pt.D1(a0)) * DX
  # TODO: Implement 2nd order LPT

  # Move the particles by the displacement
  X = X + DX

  return X, P, F
Ejemplo n.º 6
0
def cic_paint_fr(field, state, output_shape, hr_shape, halo_size, splittables, mesh, weights=None):
    """Paint the positions in ``state[0]`` onto a batch+3D field.

    Steps, in order:
      1. reshape the field into blocks following ``hr_shape``
      2. pad every block with a halo of ``halo_size`` cells
      3. CiC-paint the particles (optionally weighted by ``weights``)
      4. halo exchange between neighbouring blocks
      5. slice the halo padding away
      6. reshape back to ``output_shape``

    The ``mesh`` argument is accepted but not used.
    """
    cells_per_side = field.shape[-1].size
    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', cells_per_side // hr_shape[1].size),
        mtf.Dimension('sy_block', cells_per_side // hr_shape[2].size),
        mtf.Dimension('sz_block', cells_per_side // hr_shape[3].size)])

    def _insert_block_axes(x):
        # Three singleton axes at position 1, one per block dimension
        for _ in range(3):
            x = tf.expand_dims(x, axis=1)
        return x

    def _drop_block_axes(x):
        return x[:, 0, 0, 0]

    field = mtf.slicewise(_insert_block_axes,
                          [field],
                          output_dtype=tf.float32,
                          output_shape=blocked_shape,
                          name='my_reshape',
                          splittable_dims=splittables)

    for dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], dim.name)

    field = mesh_utils.cic_paint(field, state[0], halo_size, weights)

    # Halo exchange
    for blocks_dim, padded_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, padded_dim, halo_size)

    # Remove borders
    for dim in hr_shape[-3:]:
        field = mtf.slice(field, halo_size, dim.size, dim.name)

    return mtf.slicewise(_drop_block_axes,
                         [field],
                         output_dtype=field.dtype,
                         output_shape=output_shape,
                         name='my_dumb_reshape',
                         splittable_dims=splittables)
Ejemplo n.º 7
0
def fr_to_hr(field, hr_shape, halo_size, splittables, mesh):
    """Reshape a field into the blocked ``hr_shape`` layout with halos.

    The field is reshaped slicewise to ``hr_shape``, each of the last three
    dimensions is padded by ``halo_size`` on both sides, and the halos are
    then exchanged between neighbouring blocks.

    The ``mesh`` argument is accepted but not used.
    """

    def _insert_block_axes(x):
        # Three singleton axes at position 1, one per block dimension
        for _ in range(3):
            x = tf.expand_dims(x, axis=1)
        return x

    blocked = mtf.slicewise(_insert_block_axes,
                            [field],
                            output_dtype=field.dtype,
                            output_shape=hr_shape,
                            name='my_reshape',
                            splittable_dims=splittables)

    for dim in hr_shape[-3:]:
        blocked = mtf.pad(blocked, [halo_size, halo_size], dim.name)

    for blocks_dim, padded_dim in zip(hr_shape[1:4], blocked.shape[-3:]):
        blocked = mesh_ops.halo_reduce(blocked, blocks_dim, padded_dim,
                                       halo_size)

    return blocked
Ejemplo n.º 8
0
def shear(rfield, cfield, kv, tfnc, tfbs):
    """Accumulate onto ``rfield`` the squared real-space shear components.

    ``cswisef.shearwts`` produces six Fourier-space components from
    ``cfield``; they are consumed in the order of the index pairs (i, j)
    with i <= j. Each one is transformed to real space and its square is
    added to ``rfield`` once for i == j and twice for i != j.
    """
    components = mtf.slicewise(cswisef.shearwts, [cfield] + kv + [tfnc, tfbs],
                               output_shape=[cfield.shape] * 6,
                               output_dtype=[cfield.dtype] * 6,
                               splittable_dims=cfield.shape[:])

    index_pairs = [(i, j) for i in range(3) for j in range(i, 3)]
    for position, (i, j) in enumerate(index_pairs):
        baser = mesh_utils.c2r3d(components[position],
                                 rfield.shape[-3:],
                                 dtype=rfield.dtype)
        # Off-diagonal components are counted twice
        repeats = 1 if i == j else 2
        for _ in range(repeats):
            rfield = rfield + baser * baser
    return rfield
Ejemplo n.º 9
0
def downsample_hr_to_lr(field, lr_shape, hr_shape, downsampling_factor, halo_size, splittables, mesh):
    """Downsample a blocked, halo-padded field and reshape it to ``lr_shape``.

    A trailing singleton dimension is added for ``mesh_utils.downsample``
    and removed again afterwards. The downsampled halos are then sliced
    away and the block axes dropped by a slicewise reshape.

    The ``mesh`` argument is accepted but not used.
    """
    scale = 2**downsampling_factor

    # downsample works on a tensor with one extra trailing dimension
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    coarse = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    coarse = mtf.reshape(coarse, coarse.shape[:-1])

    # Strip the (downsampled) halo from each block dimension
    for dim in hr_shape[-3:]:
        coarse = mtf.slice(coarse, halo_size // scale, dim.size // scale,
                           dim.name)

    def _drop_block_axes(x):
        return x[:, 0, 0, 0]

    # Custom slicewise reshape instead of mtf.reshape
    return mtf.slicewise(_drop_block_axes,
                         [coarse],
                         output_dtype=field.dtype,
                         output_shape=lr_shape,
                         name='my_dumb_reshape',
                         splittable_dims=splittables)
Ejemplo n.º 10
0
def cic_readout_fr(field, state, hr_shape, halo_size, splittables, mesh):
    """Read a batch+3D field out at the positions stored in ``state[0]``.

    The field is reshaped into blocks following ``hr_shape``, padded with a
    halo of ``halo_size`` cells, halo-exchanged, and finally interpolated at
    the particle positions with ``mesh_utils.cic_readout``.

    The ``mesh`` argument is accepted but not used.
    """
    cells_per_side = field.shape[-1].size
    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', cells_per_side // hr_shape[1].size),
        mtf.Dimension('sy_block', cells_per_side // hr_shape[2].size),
        mtf.Dimension('sz_block', cells_per_side // hr_shape[3].size)])

    def _insert_block_axes(x):
        # Three singleton axes at position 1, one per block dimension
        for _ in range(3):
            x = tf.expand_dims(x, axis=1)
        return x

    field = mtf.slicewise(_insert_block_axes,
                          [field],
                          output_dtype=tf.float32,
                          output_shape=blocked_shape,
                          name='my_reshape',
                          splittable_dims=splittables)

    for dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], dim.name)

    # Halo exchange
    for blocks_dim, padded_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, padded_dim, halo_size)

    return mesh_utils.cic_readout(field, state[0], halo_size)
Ejemplo n.º 11
0
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the Mesh TensorFlow graph of the reconstruction model and loss.

    A trainable linear field is evolved (LPT initial state followed by
    ``mtfpm.nbody_single``), painted on a grid, passed through two small
    networks whose outputs are multiplied, displaced along the last axis,
    and compared to ``datasm`` in log space. A prior on the linear field is
    added to form the total loss.

    Args:
        mesh: mtf.Mesh on which all tensors are created.
        datasm: data field, imported with shape [batch, nx, ny, nz].
        rsdfactor: factor multiplying the read-out velocity ``vz``.
        M0: constant added inside both logarithms of the residual.
        R0: scale of the Gaussian smoothing applied to the residual.
        width: factor applied inside the sigmoid of the first network.
        off: optional offset added to the residual (None to skip).
        istd: optional field multiplying the residual; when None the
            residual is divided by 0.25 instead.
        x0: optional initial value of the linear field; when None it is
            drawn from a unit normal.
        nc: number of cells per side.
        bs: box size.
        batch_size: size of the batch dimension.
        a0: initial scale factor.
        a: final scale factor.
        nsteps: number of scale-factor stages between a0 and a.
        dtype: tf.float32 or tf.float64.

    Returns:
        Tuple ``(fields, metrics, kv)`` where ``fields`` is
        ``[fieldvar, final_field, model]``, ``metrics`` is
        ``[chisq, prior, loss]`` and ``kv`` the list of wave-vector tensors.

    Raises:
        ValueError: if ``dtype`` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed with an unbound-variable
        # error at the print below.
        raise ValueError("Unsupported dtype: %s" % dtype)
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not exceed half of the smallest block
    min_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block:
        new_size = int(0.5 * min_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Linear power spectrum table: column 0 is k, column 1 is P(k).
    # The file is parsed once instead of once per column.
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    #
    # Begin simulation

    # Trainable linear field, random unless an initial value is supplied
    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))

    ##
    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    # Paint the final particle positions on an empty grid
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)

    ##
    x = final_field

    # Network parameters: p* for the sigmoid network, m* for the ELU network
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Apply cwise_decic in Fourier space, then subtract one
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two copies filtered with cwise_fingauss at scales R1 and R2
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        """Sigmoid network applied to 3D-convolved features of the inputs."""
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=splittables)

    def apply_mwts(x, x1, x2):
        """ELU network applied point-wise to the three stacked inputs."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=splittables)

    model = pmodel * mmodel

    ##RSD below
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    # Regular grid of positions, replicated over the batch dimension
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    # Smoothed mass field, offset to avoid dividing by zero below
    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    masssm = masssm + 1e-5
    # NOTE(review): imasssm is never used below, and it inverts `x` rather
    # than `masssm` -- confirm whether it can be dropped.
    imasssm = mtf.pow(x, -1.)

    # Last component of the second entry of the final state, per particle
    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    # NOTE(review): the zero field created here is immediately replaced, and
    # the painting below starts from `final_field` rather than from zeros --
    # confirm this is intended.
    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(final_field,
                              final_state,
                              output_shape=part_shape,
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh,
                              weights=vzweights)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    #Shift
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X],
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    # NOTE(review): the shift below uses `vzweights`, not the `vz` computed
    # just above (which is only printed) -- confirm which one is intended.
    Xrsd = mtf.slicewise(lambda x, vz: x + tf.stack(
        [tf.zeros_like(vz), tf.zeros_like(vz), vz], 4), [X, vzweights],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)
    # Read the model at the grid positions and paint it at shifted positions
    modelread = mcomp.cic_readout_fr(model, [X],
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd],
                                  output_shape=part_shape,
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh,
                                  weights=modelread)

    model = modelrsd
    print(modelrsd)

    #Likelihood and prior here
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Field amplitude divided by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)
    mtfoff = None
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        if mtfoff is not None:
            # NOTE(review): the offset is added a second time here. This
            # mirrors the original behaviour (its comment already flagged
            # this step as suspicious) and is kept so results are unchanged.
            diff = diff + mtfoff
        # Guarding on mtfoff avoids the NameError the original raised when
        # `istd` was given without `off`.
        diff = diff * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Gaussian smoothing of the residual in Fourier space."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
Ejemplo n.º 12
0
def nbody_prototype(mesh,
                    infield=False,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the Mesh TensorFlow graph of a single-resolution simulation.

    Initial conditions come either from a placeholder (``infield=True``) or
    from ``mtfpm.linear_field``. When ``FLAGS.nbody`` is set, an LPT state at
    ``a0`` is evolved with ``mtfpm.nbody_single``; otherwise the LPT state is
    built directly at the last stage. The resulting positions are painted on
    a halo-padded blocked grid which is then reduced back to a plain grid.

    Returns the tuple ``(initc, final_field, input_field)``; the placeholder
    ``input_field`` is created whether or not ``infield`` is set.
    """
    # Scale factors at which the simulation is evaluated
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Block decomposition used for the small scales
    n_block_x, n_block_y, n_block_z = FLAGS.nx, FLAGS.ny, 1
    halo_size = FLAGS.hsize

    # Named dimensions: real-space grid, Fourier grid, low-res Fourier grid
    fx_dim, fy_dim, fz_dim = [
        mtf.Dimension(label, nc) for label in ("nx", "ny", "nz")
    ]
    tfx_dim, tfy_dim, tfz_dim = [
        mtf.Dimension(label, nc) for label in ("tx", "ty", "tz")
    ]
    tx_dim, ty_dim, tz_dim = [
        mtf.Dimension(label, nc) for label in ("tx_lr", "ty_lr", "tz_lr")
    ]

    # Number of blocks per axis, and cells per block
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # Linear power spectrum table: column 0 is k, column 1 is P(k)
    pk_path = '../flowpm/data/Planck15_a1p00.txt'
    klin, plin = [np.loadtxt(pk_path).T[column] for column in (0, 1)]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin, shape=[pk_dim])

    def _import_wavevector(component, dim):
        # 1D float32 wave vector imported along a single named dimension
        return mtf.import_tf_tensor(mesh,
                                    component.squeeze().astype('float32'),
                                    shape=[dim])

    # Fourier kernels of the full grid
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = _import_wavevector(kvec[0], tfx_dim)
    ky = _import_wavevector(kvec[1], tfy_dim)
    kz = _import_wavevector(kvec[2], tfz_dim)
    kv = [ky, kz, kx]

    # Fourier kernels of the low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = _import_wavevector(kvec_lr[0], tx_dim)
    ky_lr = _import_wavevector(kvec_lr[1], ty_dim)
    kz_lr = _import_wavevector(kvec_lr[2], tz_dim)
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Initial conditions: fed through the placeholder or drawn from pk
    input_field = tf.placeholder(dtype, [batch_size, nc, nc, nc])
    if infield:
        initc = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    else:
        initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    # With the N-body enabled LPT only provides the state at a0; otherwise
    # LPT is evaluated directly at the final scale factor.
    run_nbody = FLAGS.nbody
    lpt_scale_factor = a0 if run_nbody else stages[-1]
    lpt_state = mtfpm.lpt_init_single(
        initc,
        lpt_scale_factor,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    if run_nbody:
        final_state = mtfpm.nbody_single(lpt_state, stages, lr_shape,
                                         hr_shape, kv_lr, halo_size)
    else:
        final_state = lpt_state

    # Paint the particles on an empty, halo-padded blocked grid
    painted = mtf.zeros(mesh, shape=hr_shape)
    for dim in hr_shape[-3:]:
        painted = mtf.pad(painted, [halo_size, halo_size], dim.name)
    painted = mesh_utils.cic_paint(painted, final_state[0], halo_size)

    # Halo exchange
    for blocks_dim, padded_dim in zip(hr_shape[1:4], painted.shape[-3:]):
        painted = mpm.halo_reduce(painted, blocks_dim, padded_dim, halo_size)

    # Remove borders
    for dim in hr_shape[-3:]:
        painted = mtf.slice(painted, halo_size, dim.size, dim.name)

    def _drop_block_axes(x):
        return x[:, 0, 0, 0]

    final_field = mtf.slicewise(
        _drop_block_axes, [painted],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field, input_field
Ejemplo n.º 13
0
def nbody_fn(mesh,
             klin,
             plin,
             nc=FLAGS.nc,
             bs=FLAGS.box_size,
             batch_size=FLAGS.batch_size,
             a0=FLAGS.a0,
             a=FLAGS.af,
             nsteps=FLAGS.nsteps,
             dtype=tf.float32):
  """Build the Mesh TensorFlow graph of a pyramid (two-resolution) N-body run.

  A linear field is drawn with ``mtfpm.linear_field``, cut into
  ``FLAGS.nx x FLAGS.ny x 1`` blocks padded with ``FLAGS.hsize`` halo cells
  (the high resolution component) and downsampled by
  ``2**FLAGS.dsample`` (the low resolution component).  Both are handed to
  ``mtfpm.lpt_init`` and ``mtfpm.nbody``; the resulting particles are painted
  back on the grid with ``mesh_utils.cic_paint``.

  Args:
    mesh: ``mtf.Mesh`` on which all tensors are created.
    klin: Unused; kept so existing callers keep working.
    plin: NumPy array imported as the ``pk`` tensor given to
      ``mtfpm.linear_field`` (presumably the tabulated linear power
      spectrum -- confirm against the caller).
    nc: Number of cells per side of the full grid.
    bs: Box size forwarded to ``mtfpm.linear_field``.
    batch_size: Size of the leading batch dimension.
    a0: First value of the integration stages; also the value at which the
      LPT initial state is generated.
    a: Last value of the integration stages.
    nsteps: Number of integration stages (``np.linspace(a0, a, nsteps)``).
    dtype: Unused; every tensor built here is ``tf.float32``.

  Returns:
    Tuple ``(initc, final_field)`` of mtf tensors with shape
    ``[batch, nx, ny, nz]``: the initial linear field and the painted final
    field.
  """
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  # Parameters of the large scales decomposition
  downsampling_factor = FLAGS.dsample
  lnc = nc // 2**downsampling_factor

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions of the low resolution grid
  x_dim = mtf.Dimension("nx_lr", lnc)
  y_dim = mtf.Dimension("ny_lr", lnc)
  z_dim = mtf.Dimension("nz_lr", lnc)

  tx_dim = mtf.Dimension("tx_lr", lnc)
  ty_dim = mtf.Dimension("ty_lr", lnc)
  tz_dim = mtf.Dimension("tz_lr", lnc)

  # Number of blocks along each axis, and size of one (unpadded) block
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(
      mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
  ky = mtf.import_tf_tensor(
      mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
  kz = mtf.import_tf_tensor(
      mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
  # NOTE(review): the (y, z, x) ordering is what the mtfpm helpers are given
  # everywhere in this file -- presumably it matches the transposed layout
  # of the distributed FFT; confirm before reordering.
  kv = [ky, kz, kx]

  # kvec for low resolution grid, rescaled by the downsampling ratio
  kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)

  kx_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  # kvec for high resolution blocks (block size plus a halo on both sides)
  padded_nx = nc // n_block_x + 2 * halo_size
  padded_ny = nc // n_block_y + 2 * halo_size
  padded_nz = nc // n_block_z + 2 * halo_size
  padded_sx_dim = mtf.Dimension('padded_sx_block', padded_nx)
  padded_sy_dim = mtf.Dimension('padded_sy_block', padded_ny)
  padded_sz_dim = mtf.Dimension('padded_sz_block', padded_nz)
  kvec_hr = flowpm.kernels.fftk([padded_nx, padded_ny, padded_nz],
                                symmetric=False)

  kx_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
  ky_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
  kz_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
  kv_hr = [ky_hr, kz_hr, kx_hr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, x_dim, y_dim, z_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  # Compute distributed initial conditions
  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

  # Reshaping array into high resolution mesh: each slice gets three
  # singleton block axes inserted after the batch axis.
  field = mtf.slicewise(
      lambda x: tf.expand_dims(
          tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [initc],
      output_dtype=tf.float32,
      output_shape=hr_shape,
      name='my_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

  # Add the halo cells around every block and fill them from the neighbours
  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

  # A trailing singleton dimension is added for the downsampling call and
  # removed again right after it.
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  high = field
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

  low = mtf.reshape(low, low.shape[:-1])
  high = mtf.reshape(high, high.shape[:-1])

  # Strip the (downsampled) halo from the low resolution field
  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack: custom slicewise reshape used instead of mtf.reshape to drop the
  # block axes.
  low = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  # The initial state must be generated at the first integration stage, so
  # a0 is passed here (this used to be a hard-coded 0.1).
  state = mtfpm.lpt_init(
      low,
      high,
      a0,
      kv_lr,
      kv_hr,
      halo_size,
      hr_shape,
      lr_shape,
      part_shape[1:],
      downsampling_factor=downsampling_factor,
      antialias=True,
  )

  final_state = mtfpm.nbody(
      state,
      stages,
      lr_shape,
      hr_shape,
      kv_lr,
      kv_hr,
      halo_size,
      downsampling_factor=downsampling_factor)

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Drop the block axes to get back to [batch, nx, ny, nz]
  final_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [final_field],
      output_dtype=tf.float32,
      output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
      name='my_dumb_reshape',
      splittable_dims=part_shape[:-1] + hr_shape[:4])

  return initc, final_field
Ejemplo n.º 14
0
def benchmark_model(mesh):
  """Build a single-resolution LPT benchmark graph and return its sum.

  A linear field is drawn with ``mtfpm.linear_field``, displaced with
  ``mtfpm.lpt_init_single`` at ``FLAGS.a0``, painted on a block-decomposed
  grid (with halo exchange) using ``mesh_utils.cic_paint`` and finally
  reduced to one value with ``mtf.reduce_sum``.

  The block decomposition is fixed to 8 x 4 x 1 blocks with 4 halo cells;
  grid size, box size and batch size come from ``FLAGS.cube_size``,
  ``FLAGS.box_size`` and ``FLAGS.batch_size``.  The N-body integration is
  not part of this benchmark: the LPT state is painted directly.

  Args:
    mesh: ``mtf.Mesh`` on which all tensors are created.

  Returns:
    The mtf tensor produced by ``mtf.reduce_sum`` over the painted field.
  """
  # Setup parameters
  bs = FLAGS.box_size
  nc = FLAGS.cube_size
  batch_size = FLAGS.batch_size
  a0 = FLAGS.a0

  # Only the second column of the table is used (imported below as `pk`),
  # so the file is parsed a single time.
  plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = 8
  n_block_y = 4
  n_block_z = 1
  halo_size = 4

  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions of the "low resolution" kernels; here they have the full
  # grid size because no downsampling is applied.
  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  kv = [ky, kz, kx]

  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)
  state = mtfpm.lpt_init_single(
      initc,
      a0,
      kv_lr,
      halo_size,
      lr_shape,
      hr_shape,
      part_shape[1:],
      antialias=True,
  )

  # The N-body integration is intentionally skipped in this benchmark.
  final_state = state

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Hack: custom slicewise reshape used instead of mtf.reshape to drop the
  # block axes and get back to [batch, nx, ny, nz].
  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                              output_dtype=tf.float32,
                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                              name='my_dumb_reshape',
                              splittable_dims=part_shape[:-1] + hr_shape[:4])

  return mtf.reduce_sum(final_field)
Ejemplo n.º 15
0
def lpt_prototype(mesh,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Build the Mesh TensorFlow graph of an LPT displacement and painting.

    A linear field is drawn with ``mtfpm.linear_field``, displaced with
    ``mtfpm.lpt_init_single`` at ``a0`` and painted on a grid decomposed in
    ``FLAGS.nx x FLAGS.ny x 1`` blocks.  No N-body steps are taken.

    Args:
        mesh: ``mtf.Mesh`` on which all tensors are created.
        nc: Number of cells per side of the grid.
        bs: Box size forwarded to ``mtfpm.linear_field``.
        batch_size: Size of the leading batch dimension.
        a0: Value at which the LPT state is generated.
        a: Unused (the N-body integration is disabled); kept for
            interface compatibility.
        nsteps: Unused, see ``a``.

    Returns:
        Tuple ``(initc, final_field)`` of mtf tensors with shape
        ``[batch, nx, ny, nz]``: the initial linear field and the painted
        LPT field.
    """
    # Only the second column of the table is used (imported below as `pk`),
    # so the file is parsed a single time.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block; clamp it otherwise.
    smallest_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # "Low resolution" kernel dimensions; full size, no downsampling here.
    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation
    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    state = mtfpm.lpt_init_single(
        initc,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # The N-body integration is disabled: the LPT state is painted directly.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack: custom slicewise reshape used instead of mtf.reshape to drop the
    # block axes and get back to [batch, nx, ny, nz].
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field
Ejemplo n.º 16
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the Mesh TensorFlow graph used to reconstruct a linear field.

    The quantity being optimised is the mesh variable ``'linear'``.  It is
    evolved (LPT only, or LPT followed by N-body steps when ``FLAGS.nbody``
    is set), painted on the grid, and passed through the two small networks
    whose parameters come from ``setupfnn()``; their product is the model
    that is compared to ``data``.  The loss is a smoothed log-residual term
    (``chisq``) plus a power-spectrum prior on the variable, and a gradient
    descent update is built from the high-pass filtered gradient of that
    loss.

    Args:
        mesh: ``mtf.Mesh`` on which all tensors are created.
        data: Array-like observed field; its ``dtype`` and ``shape`` are
            used for the placeholders and it is imported with shape
            ``[batch, nx, ny, nz]``.
        nc: Number of cells per side of the grid.
        bs: Box size.
        batch_size: Size of the leading batch dimension.
        a0: First value of the integration stages.
        a: Last value of the integration stages.
        nsteps: Number of integration stages.
        dtype: ``tf.float32`` or ``tf.float64``; selects the matching
            complex dtype used for the Fourier-space tensors.

    Returns:
        Tuple ``(initc, model, loss, var_grads, update_op, linearop,
        input_field, lr, R0, M0, width, chisq, prior, off, istd)``.
        ``input_field``, ``lr``, ``R0``, ``M0``, ``width``, ``off`` and
        ``istd`` are TensorFlow placeholders to be fed by the caller;
        ``linearop`` assigns ``input_field`` to the variable and
        ``update_op`` applies one gradient step.

    Raises:
        ValueError: If ``dtype`` is neither ``tf.float32`` nor
            ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "dtype must be tf.float32 or tf.float64, got %r" % (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Smoothing scales of the two Gaussian filters, and integration stages
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block; clamp it otherwise.
    smallest_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is used (imported below as `pk`),
    # so the file is parsed a single time.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the kernels are always float32, even when dtype is
    # tf.float64 -- confirm this is intended.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    # The linear field is a variable; `linearop` loads it from the
    # `input_field` placeholder.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only: displace directly to the last stage.
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Drop the block axes to get back to [batch, nx, ny, nz]
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    x = final_field

    # Network parameters: weights, biases and normalisation constants
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)

    # Fourier dimensions, reordered from the (y, z, x) order of `kv`
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Network inputs: the field filtered with `cwise_decic` (minus one), and
    # that field filtered with `cwise_fingauss` at scales R1 and R2.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    # Multiplies the pre-activation of the final sigmoid in `apply_pwts`
    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        """Apply the sigmoid-output network to the convolved inputs."""
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        """Apply the ELU network to the stacked inputs and rescale it."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    cfield = mesh_utils.r2c3d(fieldvar, k_dims, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated `pk`."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # `pk` is treated as sampled on a log-regular grid over [1e-5, 1e3]
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss: the model is filtered with `cwise_fingauss` at scale R1
    # before comparison
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Annealing controls fed by the caller
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder(
        tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    # NOTE: `mtfistd` is currently not used in the loss; a fixed 0.25 is used
    # instead because `(diff + mtfoff) * mtfistd` was reported to misbehave.
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by a Gaussian in k whose width is set by R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply kfield by one minus a Gaussian in k (high-pass)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss, high-pass filtered in Fourier space
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    # Plain gradient descent step with a caller-supplied learning rate
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
Ejemplo n.º 17
0
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """
    Build the mesh-tensorflow graph of a reconstruction model.

    A trainable field (variable 'linear') is evolved with
    `mtfpm.lpt_init_single` (followed by `mtfpm.nbody_single` when
    FLAGS.nbody is set), painted on the mesh, smoothed with a Gaussian kernel
    in Fourier space and compared to `data` through the negative log
    probability of a Poisson distribution of rate
    `FLAGS.plambda * (1 + smoothed_field)`. A prior term built from the
    tabulated power spectrum is added to form the loss.

    Parameters:
    -----------
    mesh:
      Mesh on which all mtf tensors and the variable are created.

    data:
      Observed field; converted with `tf.convert_to_tensor` and imported
      with shape [batch, nx, ny, nz].

    R0:
      Smoothing scale; wrapped in a `tf.constant` and used in the kernel
      exp(-k^2 * (R0 * bs / nc)^2).

    x0:
      Initial value of the field variable. When None the variable is
      initialized from a normal distribution (mean 0, stddev 1).

    nc: int
      Size of each of the three spatial dimensions of the mesh.

    bs:
      Box size, used to rescale the wave vectors.

    batch_size: int
      Size of the batch dimension.

    a0, a, nsteps:
      The integration stages are `np.linspace(a0, a, nsteps)`.

    dtype:
      tf.float32 or tf.float64.

    Returns:
    --------
    fields: list
      [fieldvar, sample], the trainable field and a Poisson sample drawn
      with the model rate.

    metrics: list
      [chisq, prior, loss] with loss = chisq + prior.

    kv: list
      The imported wave vectors, ordered [ky, kz, kx].

    Raises:
    -------
    ValueError
      If `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch an unsupported dtype only surfaced as an
        # unbound `npdtype` on the print below.
        raise ValueError('Unsupported dtype %s: expected tf.float32 or '
                         'tf.float64' % (dtype, ))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is capped at half the size of the smallest block
    min_block = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block:
        new_size = int(0.5 * min_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the tabulated power spectrum once: column 0 is k, column 1 is P(k)
    pk_table = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    # NOTE(review): ipklin is not used below.
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the wave vectors are imported as float32 whatever `dtype`
    # is -- confirm this is intended for the float64 path.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    # The field being reconstructed is the only trainable variable
    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only: initialize directly at the last stage
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Drop the three block axes to view the blocks as a [batch, nx, ny, nz] field
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # kv is ordered [ky, kz, kx]; reorder the dimensions to x, y, z
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # Amplitude of each mode divided by the square root of the power
        # spectrum interpolated at |k|
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    # `diff` is only used below for the shape of its last three dimensions
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    # Element-wise function that applies a Gaussian smoothing kernel
    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        # Negative Poisson log probability of the data given the model rate
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)  #
    chisq = mtf.reduce_sum(chisq)
    ##    #

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        # `data` is not used; it is only part of the cwise input list
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)  #
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """
    Prototype of a function computing the LPT displacement on a two-level
    (low / high resolution) mesh and painting the resulting particles.

    The initial conditions are split into blocks, padded with halos and
    downsampled by 2**FLAGS.dsample to get the low resolution field. Both
    fields are passed to `mtfpm.lpt_init` at scale factor `a0`, and the
    particle positions of the resulting state are painted on the mesh.

    Parameters:
    -----------
    mesh:
      Mesh on which all mtf tensors are created.

    initial_conditions:
      Tensorflow tensor that can be reshaped to
      [batch_size, nx, nc // nx, ny, nc // ny, 1, nc] with nx = FLAGS.nx and
      ny = FLAGS.ny.

    derivs:
      Sequence of four items (lap, grad_x, grad_y, grad_z). It is unpacked
      but the items are not used here.

    nc: int
      Size of each of the three spatial dimensions of the mesh.

    bs:
      Box size (not used in this function).

    batch_size: int
      Size of the batch dimension.

    a0, a, nsteps:
      The stages are `np.linspace(a0, a, nsteps)`. They are only needed by
      the N-body call, which is currently commented out.

    Returns:
    --------
    final_field:
      mtf tensor of shape [batch, nx, ny, nz] holding the painted field.
    """

    stages = np.linspace(a0, a, nsteps, endpoint=True)
    lap, grad_x, grad_y, grad_z = derivs
    # Read the tabulated power spectrum once: column 0 is k, column 1 is P(k)
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    # NOTE(review): ipklin is not used below.
    ipklin = iuspline(klin, plin)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition
    downsampling_factor = FLAGS.dsample
    lnc = nc // 2**downsampling_factor

    #

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)

    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False)

    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype('float32'),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype('float32'),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype('float32'),
                                 shape=[padded_sz_dim])
    kv_hr = [kx_hr, ky_hr, kz_hr]

    lr_shape = [batch_dim, x_dim, y_dim, z_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Split x and y into (block, position in block) and move the block axes
    # in front so the layout matches hr_shape. The leading axis uses
    # batch_size (it was hard-coded to 1) so it agrees with batch_dim.
    initc = tf.reshape(initial_conditions, [
        batch_size, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y, 1,
        nc
    ])
    initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6])
    field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape)

    # Pad each block with a halo, then exchange halos between blocks
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # Split the field into its high and low resolution versions; a trailing
    # dimension of size 1 is added for the downsampling and removed after
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Remove the (downsampled) halo from the low resolution field
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=tf.float32,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Hack to handle reshape across multiple dimensions
    #low = mtf.reshape(low, [batch_dim, x_dim, low.shape[2], low.shape[5], z_dim])
    #low = mtf.reshape(low, lr_shape)

    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        k_dims,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    # Here we can run our nbody
    final_state = state  #mtfpm.nbody(state, stages, lr_shape, hr_shape, k_dims, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    #final_field = mtf.reshape(final_field,  [batch_dim, fx_dim, fy_dim, fz_dim])
    # Hack using custom reshape because mesh is pretty dumb
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return final_field
Ejemplo n.º 19
0
def slicewise_c2r3d(cfield):
    """
    Apply an inverse 3D FFT to every slice of `cfield` and cast the result
    to float32.

    All dimensions of `cfield` except the last three are declared
    splittable.
    """

    def _ifft3d_as_float(x):
        return tf.cast(tf.signal.ifft3d(x), tf.float32)

    leading_dims = cfield.shape[:-3]
    return mtf.slicewise(_ifft3d_as_float, [cfield],
                         output_dtype=tf.float32,
                         splittable_dims=leading_dims)
Ejemplo n.º 20
0
def slicewise_r2c3d(rfield):
    """
    Cast every slice of `rfield` to complex64 and apply a 3D FFT to it.

    All dimensions of `rfield` except the last three are declared
    splittable.
    """

    def _fft3d_of_real(x):
        return tf.signal.fft3d(tf.cast(x, tf.complex64))

    leading_dims = rfield.shape[:-3]
    return mtf.slicewise(_fft3d_of_real, [rfield],
                         output_dtype=tf.complex64,
                         splittable_dims=leading_dims)
Ejemplo n.º 21
0
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
    """
    Estimate force on the particles given a state, using a two-level
    (low / high resolution) mesh.

    The particles are painted on the block-decomposed high resolution mesh.
    A downsampled copy gives the low resolution field. The gradient-Laplace
    kernel is applied to both in Fourier space; the low resolution result is
    upsampled, the large scales are removed from the high resolution result,
    and their sum is read out at the particle positions.

    Parameters:
    -----------
    state: sequence (X, P, F)
      Current state. Only X is read here; P is returned unchanged and F is
      replaced.

    lr_shape: list of mtf.Dimension
      Shape of the low resolution field; the last three entries are the
      spatial dimensions.

    hr_shape: list of mtf.Dimension
      Shape of the block-decomposed field. Entries 1:4 are used as the block
      count dimensions and the last three as the per-block dimensions.

    kvec_lr, kvec_hr: list
      Wave vectors of the low and high resolution grids. Their first
      dimensions are reordered as [2], [0], [1] to get the FFT dimensions.

    halo_size: int
      Width of the halo padded around each block.

    cosmology:
      Cosmology object; only its `Om0` attribute is read.

    downsampling_factor: int
      The low resolution field is 2**downsampling_factor coarser.

    pm_nc_factor: int
      Only the value 1 is supported.

    antialias: bool
      Passed to the downsampling that removes the large scales from the high
      resolution field.

    Returns:
    --------
    X, P, F:
      The input X and P with the newly computed force F.
    """
    X, P, F = state
    #TODO: support different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    # Reorder the FFTs which were transposed: y,z,x
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)

    # Split the field into low and high resolution
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    # NOTE(review): this downsampling always uses antialias=True and ignores
    # the `antialias` argument -- confirm this is intended.
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])
    hr_field = mtf.reshape(high, high.shape[:-1])
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)

    # Hack using custom reshape because mesh is pretty dumb
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield,
                                                    kvec_lr,
                                                    r_split=0)
    # Chain on the long-range filtered field, as done for the high resolution
    # branch below; passing `lr_kfield` here discarded the previous result.
    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(kfield_lr, kvec_lr)
    kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield,
                                                    kvec_hr,
                                                    r_split=0)
    kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

    # Reorder the FFTs which were transposed: y,z,x
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
    kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

    displacement = []
    for f, g in zip(kfield_lr, kfield_hr):
        # Low resolution component: back to real space, then to block layout
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        # Pad with the downsampled halo, exchange halos and upsample
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        # High resolution component
        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        # Readout the force to particle positions
        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Stack the three components along a new "ndim" dimension
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
Ejemplo n.º 22
0
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """
    Prototype of a function computing the LPT displacement on a single mesh
    and painting the resulting particles.

    The initial conditions are imported as an mtf tensor and passed to
    `mtfpm.lpt_init_single` at scale factor `a`. The particle positions of
    the resulting state are painted on the block-decomposed mesh.

    Parameters:
    -----------
    mesh:
      Mesh on which all mtf tensors are created.

    initial_conditions:
      Tensorflow tensor imported with shape [batch, nx, ny, nz].

    derivs:
      Sequence of four items (lap, grad_x, grad_y, grad_z). It is unpacked
      but the items are not used here.

    nc: int
      Size of each of the three spatial dimensions of the mesh.

    bs:
      Box size (not used in this function).

    batch_size: int
      Size of the batch dimension.

    a0, a, nsteps:
      The stages are `np.linspace(a0, a, nsteps)`. They are only needed by
      the N-body call, which is currently commented out.

    Returns:
    --------
    final_field:
      mtf tensor of shape [batch, fnx, fny, fnz] holding the painted field.
    """

    stages = np.linspace(a0, a, nsteps, endpoint=True)
    lap, grad_x, grad_y, grad_z = derivs
    # Read the tabulated power spectrum once: column 0 is k, column 1 is P(k)
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    # NOTE(review): ipklin is not used below.
    ipklin = iuspline(klin, plin)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition
    downsampling_factor = 0
    lnc = nc // 2**downsampling_factor

    #

    # Names of the dimensions of the returned field
    ffx_dim = mtf.Dimension("fnx", nc)
    ffy_dim = mtf.Dimension("fny", nc)
    ffz_dim = mtf.Dimension("fnz", nc)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    field = mtf.import_tf_tensor(mesh, initial_conditions, shape=part_shape)

    # The state is initialized directly at the final scale factor `a`
    state = mtfpm.lpt_init_single(
        field,
        a,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # NOTE(review): looks like leftover debug output -- consider removing.
    print('TOTO', state)
    # Here we can run our nbody
    final_state = state
    # final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
    #                                  kv_lr, halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    #final_field = mtf.reshape(final_field,  [batch_dim, fx_dim, fy_dim, fz_dim])
    # Hack using custom reshape because mesh is pretty dumb
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Rename the spatial dimensions to fnx, fny, fnz
    final_field = mtf.reshape(final_field,
                              [batch_dim, ffx_dim, ffy_dim, ffz_dim])
    return final_field
Ejemplo n.º 23
0
def force_single(state,
                 lr_shape,
                 hr_shape,
                 kvec_lr,
                 halo_size,
                 cosmology=Planck15,
                 pm_nc_factor=1,
                 **kwargs):
    """
    Estimate force on the particles given a state, using a single mesh.

    The particles are painted on the halo-padded, block-decomposed mesh; the
    halos are reduced and stripped and the blocks are viewed as one field of
    shape `lr_shape`. The gradient-Laplace kernel is applied in Fourier
    space and each resulting component is read out at the particle
    positions.

    Parameters:
    -----------
    state: sequence (X, P, F)
      Current state. Only X is read here; P is returned unchanged and F is
      replaced.

    lr_shape: list of mtf.Dimension
      Shape of the full field; the last three entries are the spatial
      dimensions.

    hr_shape: list of mtf.Dimension
      Shape of the block-decomposed field. Entries 1:4 are used as the block
      count dimensions and the last three as the per-block dimensions.

    kvec_lr: list
      Wave vectors of the grid. Their first dimensions are reordered as
      [2], [0], [1] to get the FFT dimensions.

    halo_size: int
      Width of the halo padded around each block.

    cosmology:
      Cosmology object; only its `Om0` attribute is read.

    pm_nc_factor: int
      Only the value 1 is supported.

    Returns:
    --------
    X, P, F:
      The input X and P with the newly computed force F.
    """
    pos, mom, _ = state
    #TODO: support different factor
    assert pm_nc_factor == 1
    ncell = lr_shape[-1].size
    pos_shape = pos.shape
    block_dims = hr_shape[-3:]

    def _pad_halo(tsr):
        # Add a halo on both sides of each per-block dimension
        for dim in block_dims:
            tsr = mtf.pad(tsr, [halo_size, halo_size], dim.name)
        return tsr

    def _reduce_halo(tsr):
        # The padded dimensions are read once, before the first reduction
        padded_dims = tsr.shape[-3:]
        for n_blocks, padded in zip(hr_shape[1:4], padded_dims):
            tsr = mesh_ops.halo_reduce(tsr, n_blocks, padded, halo_size)
        return tsr

    def _drop_block_axes(x):
        return x[:, 0, 0, 0]

    def _add_block_axes(x):
        for _ in range(3):
            x = tf.expand_dims(x, axis=1)
        return x

    # Paint the particles on the high resolution mesh
    painted = _pad_halo(mtf.zeros(pos.mesh, shape=hr_shape))
    painted = mesh_utils.cic_paint(painted, pos, halo_size)
    painted = _reduce_halo(painted)
    # Remove borders
    for dim in block_dims:
        painted = mtf.slice(painted, halo_size, dim.size, dim.name)

    # Hack using custom reshape because mesh is pretty dumb
    lr_field = mtf.slicewise(_drop_block_axes, [painted],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # The FFT dimensions are the kvec dimensions reordered from y,z,x
    k_sizes = [k.shape[0] for k in kvec_lr]
    fft_dims = [k_sizes[i] for i in (2, 0, 1)]
    lr_kfield = mesh_utils.r2c3d(lr_field, fft_dims)

    grads = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

    # Reorder the FFTs which were transposed: y,z,x
    grads = [grads[i] for i in (2, 0, 1)]

    def _readout(kcomp):
        # Back to real space, then to the block layout
        real = mesh_utils.c2r3d(kcomp, lr_shape[-3:])
        real = mtf.slicewise(
            _add_block_axes, [real],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', ncell // hr_shape[1].size),
                mtf.Dimension('sy_block', ncell // hr_shape[2].size),
                mtf.Dimension('sz_block', ncell // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + pos_shape[1:3])
        real = _reduce_halo(_pad_halo(real))
        # Readout the force to particle positions
        return mesh_utils.cic_readout(real, pos, halo_size)

    # Stack the components along a new "ndim" dimension
    force = mtf.stack([_readout(kcomp) for kcomp in grads], "ndim", axis=4)

    force = force * 1.5 * cosmology.Om0
    return pos, mom, force
Ejemplo n.º 24
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the Mesh TensorFlow graph of a reconstruction prototype.

    A variable holding the linear field (stored in block-decomposed layout)
    is evolved with ``mtfpm.lpt_init`` (optionally followed by
    ``mtfpm.nbody`` when ``FLAGS.nbody`` is set), painted back on the grid
    and compared with ``data``. The loss is the sum of a smoothed chi-square
    term and a power-spectrum prior term; the function also builds a
    high-pass filtered gradient of that loss and a gradient-descent update
    op for the variable.

    Args:
        mesh: Mesh TensorFlow mesh on which all tensors are created.
        data: Observed field; must expose ``.dtype`` and be convertible with
            ``tf.convert_to_tensor``. It is imported with shape
            ``[batch, nc, nc, nc]``.
        nc: Number of cells per side of the full-resolution grid.
        bs: Box size (same units as used by the k-vector scaling
            ``k / bs * nc``).
        batch_size: Size of the batch dimension.
        a0: First value of the ``stages`` array passed to ``mtfpm.nbody``.
        a: Last value of the ``stages`` array.
        nsteps: Number of entries of the ``stages`` array.
        dtype: ``tf.float32`` or ``tf.float64``; selects the numpy dtype of
            the imported kernels and the complex dtype of the FFTs.

    Returns:
        A tuple ``(initc, final_field, loss, var_grads, update_op, linearop,
        input_field, lr, R0)`` where

        * ``initc`` is the variable reshaped to ``[batch, nc, nc, nc]``,
        * ``final_field`` is the painted field after evolution,
        * ``loss`` is ``chisq + prior``,
        * ``var_grads`` is a one-element list with the high-pass filtered
          gradient of ``loss`` with respect to the variable,
        * ``update_op`` assigns ``fieldvar - var_grads[0] * lr``,
        * ``linearop`` assigns the ``input_field`` placeholder to the
          variable,
        * ``input_field``, ``lr`` and ``R0`` are the ``tf.placeholder``
          objects for the initial field, the step size and the smoothing
          scale.

    Raises:
        ValueError: If ``dtype`` is neither ``tf.float32`` nor
            ``tf.float64``.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously fell through and crashed below with an unbound local.
        raise ValueError(
            'recon_prototype only supports tf.float32 and tf.float64, '
            'got %s' % (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Number of cells per block along each axis.
    block_nx = nc // n_block_x
    block_ny = nc // n_block_y
    block_nz = nc // n_block_z

    # Clamp the halo to half of the smallest block size.
    smallest_block = min(block_nx, block_ny, block_nz)
    if halo_size >= 0.5 * smallest_block:
        new_size = int(0.5 * smallest_block)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', block_nx)
    sy_dim = mtf.Dimension('sy_block', block_ny)
    sz_dim = mtf.Dimension('sz_block', block_nz)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Tabulated power spectrum: only the second column is used below, so the
    # file is parsed a single time.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    def _import_kvec(kvec_np, dims, divisor=None):
        """Import the three 1D k-vectors on the mesh, in (x, y, z) order."""
        imported = []
        for k_np, dim in zip(kvec_np, dims):
            k_np = k_np.squeeze().astype(npdtype)
            if divisor is not None:
                k_np = k_np / divisor
            imported.append(mtf.import_tf_tensor(mesh, k_np, shape=[dim]))
        return imported

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    # The full-resolution k-vectors are not consumed below; the imports are
    # kept so that the set of ops registered on the mesh is unchanged.
    _import_kvec(kvec, [tfx_dim, tfy_dim, tfz_dim])

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_lr, ky_lr, kz_lr = _import_kvec(kvec_lr, [tx_dim, ty_dim, tz_dim],
                                       divisor=2**downsampling_factor)
    # Kernels are passed around in (y, z, x) order.
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  block_nx + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  block_ny + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  block_nz + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        block_nx + 2 * halo_size, block_ny + 2 * halo_size,
        block_nz + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr, ky_hr, kz_hr = _import_kvec(
        kvec_hr, [padded_sx_dim, padded_sy_dim, padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', block_nx)
    prior_sy_dim = mtf.Dimension('prior_sy_block', block_ny)
    prior_sz_dim = mtf.Dimension('prior_sz_block', block_nz)

    kvec_pr = flowpm.kernels.fftk([block_nx, block_ny, block_nz],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_pr, ky_pr, kz_pr = _import_kvec(
        kvec_pr, [prior_sx_dim, prior_sy_dim, prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Distributed initial conditions: a variable in block layout, plus a
    # placeholder and an assign op to initialise it from the host.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, block_nx, block_ny,
        block_nz
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    field = fieldvar
    # Custom reshape from block layout to [batch, nx, ny, nz] via slicewise.
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Pad every block with a halo and exchange the halos between blocks.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # A trailing singleton dimension is added for the downsampling call and
    # removed again right after.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Custom reshape from block layout to the low resolution grid.
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        # NOTE(review): the initial scale factor is hard-coded to 0.1 while
        # ``stages`` starts at ``a0`` -- confirm whether ``a0`` was intended.
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    def _scaled_k_squared(kx, ky, kz):
        """Return |k|^2 on the 3D grid, with each k scaled by nc / bs."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        return (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv_pr]
    # kv_pr is ordered (y, z, x); reorder its dimensions to (x, y, z).
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated P(k)."""
        kk = tf.sqrt(_scaled_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # Assumes the tabulated spectrum is sampled on a log-regular grid
        # spanning [1e-5, 1e3] -- TODO confirm against the data file.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        """Apply the Gaussian weights exp(-k^2 (R0 * bs / nc)^2)."""
        kk = _scaled_k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Apply one minus a Gaussian of width R0 * bs / nc + 1 / nyq."""
        kk = _scaled_k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss, high-pass filtered in Fourier space.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    # Plain gradient-descent step with a fed learning rate.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0