예제 #1
0
    def get_indices(self, keys: mtf.Tensor,
                    query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
        """Return the top-``knn`` product-key scores and their flat indices.

        The query is scored against the keys, and the best ``self.knn`` keys
        are kept for each of the two halves along the second-to-last score
        dimension. Every pair of survivors is then combined: scores are summed
        and indices are flattened as ``i1 * self.n_keys + i2``. The best
        ``self.knn`` of those ``knn**2`` pairs are returned.

        Args:
            keys: Key tensor. Dimension 2 becomes the last dimension of the
                score tensor. The other dimensions are presumably shared with
                or contracted against ``query`` (TODO confirm layout).
            query: Query tensor. Its last dimension does not appear in the
                score shape, so the einsum contracts it.

        Returns:
            ``(scores, indices)``, both ending in a ``"knn"`` dimension.
        """
        knn_dim = mtf.Dimension("knn", self.knn)
        knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)

        # Scores over every key of each half, e.g. [b, l, h, 2, n_keys].
        score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
        raw_scores = mtf.einsum([query, keys], output_shape=score_shape)
        # Keep the best knn keys per half: [b, l, h, 2, knn].
        half_scores, half_indices = mtf.top_k(raw_scores,
                                              score_shape.dims[-1], knn_dim)

        def _outer_sum(left, right):
            """Broadcast-add two [..., knn] tensors into [..., knn**2]."""
            # Rename so that the add broadcasts to [..., knn, knn2].
            right = mtf.rename_dimension(right, "knn", "knn2")
            pair_shape = mtf.Shape(left.shape.dims + right.shape.dims[-1:])
            summed = mtf.add(left, right, output_shape=pair_shape)
            return mtf.replace_dimensions(summed, pair_shape[-2:],
                                          knn_square_dim)

        first_scores, second_scores = mtf.unstack(half_scores,
                                                  half_scores.shape.dims[-2])
        pair_scores = _outer_sum(first_scores, second_scores)

        first_idx, second_idx = mtf.unstack(half_indices,
                                            half_indices.shape.dims[-2])
        # Flatten each (i1, i2) pair into a single product-key index,
        # assuming self.n_keys matches the size of the key axis.
        pair_indices = _outer_sum(mtf.multiply(first_idx, self.n_keys),
                                  second_idx)

        top_scores, top_positions = mtf.top_k(pair_scores,
                                              pair_scores.shape.dims[-1],
                                              knn_dim)
        return top_scores, mtf.gather(pair_indices, top_positions,
                                      knn_square_dim)
예제 #2
0
  def testWhileLoopOperation(self):
    """Checks which dims a WhileLoopOperation reports as splittable.

    The loop built here is equivalent to:
      for i in range(10):
        x = x * 2
    The test expects splittable dims {"a", "b"} (presumably the dims of
    self.x) and no unsplittable dims.
    """
    counter = mtf.constant(self.mesh, 0, mtf.Shape([]))

    def cond_fn(i, x):
      return mtf.less(i, 10)

    def body_fn(i, x):
      return [mtf.add(i, 1), mtf.multiply(x, 2)]

    loop_op = mtf.WhileLoopOperation(cond_fn, body_fn, [counter, self.x])

    expected_splittable = frozenset(["a", "b"])
    self.assertEqual(loop_op.splittable_dims, expected_splittable)
    self.assertEqual(loop_op.unsplittable_dims, frozenset())
예제 #3
0
def recon_model(mesh,
                datasm,
                rsdfactor,
                M0,
                R0,
                width,
                off,
                istd,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the mesh-TensorFlow forward model and reconstruction loss.

    The graph does the following:
    1. Starts from a trainable linear field variable ('linear').
    2. Evolves it with ``mtfpm.lpt_init_single`` and
       ``mtfpm.nbody_single``.
    3. CIC-paints the final particles to a grid.
    4. Applies two small neural networks, with weights from
       ``setupfnn()``, to obtain the model field ``pmodel * mmodel``.
    5. Shifts that field along the last axis by the RSD velocity.
    6. Compares a smoothed log of the result with ``datasm``, and adds a
       power-spectrum prior on the linear field.

    Args:
        mesh: mtf.Mesh on which every tensor is created.
        datasm: Data field, imported with shape [batch, nc, nc, nc].
        rsdfactor: Scalar that multiplies the smoothed line-of-sight velocity
            to produce the RSD displacement.
        M0: Offset added inside both logs of the likelihood,
            ``log(model + M0) - log(data + M0)``.
        R0: Scale of the exponential k-space filter applied to the residual.
        width: Multiplier inside the sigmoid of the ``pmodel`` network.
        off: Optional offset field added to the residual, or None.
        istd: Optional inverse-std field that weights the residual, or None.
            When None, the residual is divided by 0.25 instead.
        x0: Optional initial value for the linear field variable. When None,
            a unit-variance normal initializer is used.
        nc: Number of grid cells per dimension.
        bs: Box size (FLAGS.box_size).
        batch_size: Size of the batch dimension.
        a0: Initial scale factor.
        a: Final scale factor.
        nsteps: Number of stages, ``np.linspace(a0, a, nsteps)``.
        dtype: tf.float32 or tf.float64.

    Returns:
        ``(fields, metrics, kv)`` where ``fields = [fieldvar, final_field,
        model]``, ``metrics = [chisq, prior, loss]`` and ``kv`` is the list of
        imported wavevector tensors ``[ky, kz, kx]``.

    Raises:
        ValueError: If ``dtype`` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this check, npdtype/cdtype would be unbound and the next
        # line would fail with an unhelpful NameError.
        raise ValueError("Unsupported dtype %s; expected tf.float32 or "
                         "tf.float64" % (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Filter scales passed to cwise_fingauss for the network inputs.
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Small-scale decomposition: block counts per axis and halo size.
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo must be smaller than half of the smallest local block.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    scalar = mtf.Dimension("scalar", 1)

    # Real-space grid dimensions.
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the imported wavevector tensors.
    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    # Block decomposition: number of blocks and block size per axis.
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Column 1 of the power-spectrum table is used by the prior below.
    # The file is read once (previously it was loaded twice).
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Fourier kernels.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # Kept in the original [ky, kz, kx] order; k_dims below is permuted to
    # undo it.
    kv = [ky, kz, kx]

    # Wavevectors for the low-resolution grid.
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # ---- Simulation -------------------------------------------------------
    if x0 is None:
        initializer = tf.random_normal_initializer(mean=0.0,
                                                   stddev=1,
                                                   seed=None)
    else:
        initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=initializer)

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field,
                                     final_state,
                                     output_shape=part_shape,
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)

    # ---- Neural-network model field ---------------------------------------
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]

    # Reorder the [ky, kz, kx] dimensions back to x, y, z order.
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two filtered copies (scales R1 and R2) of the deconvolved field.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    def apply_pwts(x, x1, x2):
        """Per-slice network with a sigmoid output (the pmodel factor)."""
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(tf.constant(width) * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(apply_pwts, [x, x1, x12],
                           output_dtype=tf.float32,
                           output_shape=part_shape,
                           name='apply_pwts',
                           splittable_dims=splittables)

    def apply_mwts(x, x1, x2):
        """Per-slice network with an affine-rescaled output (mmodel factor)."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(apply_mwts, [x, x1, x12],
                           output_dtype=tf.float32,
                           output_shape=part_shape,
                           name='apply_mwts',
                           splittable_dims=splittables)

    model = pmodel * mmodel

    # ---- Redshift-space distortions ---------------------------------------
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    # Grid-point coordinates, broadcast over the batch dimension.
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    massf = mesh_utils.r2c3d(final_field, k_dims, dtype=cdtype)
    masssmf = mtf.cwise(cwise_fingauss,
                        [massf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    masssm = mesh_utils.c2r3d(masssmf, final_field.shape[-3:], dtype=dtype)
    # Small floor so that the division below never hits an exact zero.
    masssm = masssm + 1e-5

    # Last component of final_state[1], used as painting weights. Presumably
    # the line-of-sight velocity/momentum -- confirm against mtfpm.
    vzweights = final_state[1]
    vzweights = mtf.slicewise(lambda x: x[:, :, :, :, -1], [vzweights],
                              output_dtype=tf.float32,
                              output_shape=vzweights.shape[:-1],
                              name='get_vz',
                              splittable_dims=vzweights.shape[1:-1])
    print("weights : ", vzweights)

    # Paint the weighted particles onto an empty grid. The original code
    # painted onto final_field, which mixed the mass field into momz and
    # made velzsm roughly 1 + v instead of v.
    momz = mtf.zeros(mesh, shape=part_shape)
    momz = mcomp.cic_paint_fr(momz,
                              final_state,
                              output_shape=part_shape,
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh,
                              weights=vzweights)
    momzf = mesh_utils.r2c3d(momz, k_dims, dtype=cdtype)
    momzsmf = mtf.cwise(cwise_fingauss,
                        [momzf, float_to_mtf(R1, mesh, scalar)] + kv +
                        [tfnc, tfbs],
                        output_dtype=cdtype)
    momzsm = mesh_utils.c2r3d(momzsmf, momz.shape[-3:], dtype=dtype)

    # Smoothed velocity = smoothed momentum / smoothed mass, read at grid
    # points and scaled to an RSD displacement.
    velzsm = mtf.divide(momzsm, masssm)
    vz = mcomp.cic_readout_fr(velzsm, [X],
                              hr_shape=hr_shape,
                              halo_size=halo_size,
                              splittables=splittables,
                              mesh=mesh)
    vz = mtf.multiply(vz, rsdfactor)
    print("vz : ", vz)

    def _add_los_shift(pos, shift):
        """Add ``shift`` to the last coordinate of the positions ``pos``."""
        zeros = tf.zeros_like(shift)
        return pos + tf.stack([zeros, zeros, shift], 4)

    # Shift by the scaled readout vz. The original passed vzweights here,
    # which left vz (and therefore rsdfactor) without effect.
    Xrsd = mtf.slicewise(_add_los_shift, [X, vz],
                         output_dtype=tf.float32,
                         output_shape=X.shape,
                         name='add_vz',
                         splittable_dims=X.shape[1:-1])
    print(Xrsd)
    modelread = mcomp.cic_readout_fr(model, [X],
                                     hr_shape=hr_shape,
                                     halo_size=halo_size,
                                     splittables=splittables,
                                     mesh=mesh)
    modelrsd = mtf.zeros(mesh, shape=part_shape)
    modelrsd = mcomp.cic_paint_fr(modelrsd, [Xrsd],
                                  output_shape=part_shape,
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh,
                                  weights=modelread)

    model = modelrsd
    print(modelrsd)

    # ---- Likelihood and prior ---------------------------------------------
    mtfdatasm = mtf.import_tf_tensor(mesh,
                                     tf.convert_to_tensor(datasm),
                                     shape=shape)

    cfield = mesh_utils.r2c3d(fieldvar, k_dims, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Return |kfield| / sqrt(P(k)), with P interpolated on a log-k grid."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # Assumes the pk table is log-spaced between 1e-5 and 1e3 -- confirm
        # against the data file.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Residual in log space; M0 offsets both model and data (annealing knob).
    M0 = tf.constant(M0)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdatasm + M0)
    mtfoff = None
    if off is not None:
        mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
        diff = diff + mtfoff
    if istd is not None:
        mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
        if mtfoff is not None:
            # NOTE(review): the offset is added a second time here. This
            # matches the original ("doing things wrong this one") and is kept
            # deliberately. The original crashed with a NameError when
            # off was None.
            diff = diff + mtfoff
        diff = diff * mtfistd
    else:
        diff = diff / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by exp(-|k|^2 (R0 bs / nc)^2)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv