Example #1
import tensorflow as tf
import mesh_tensorflow as mtf
from astropy.cosmology import Planck15
# mesh_ops, mesh_utils, mesh_kernels and PerturbationGrowth are project-local
# helpers (assumed here to come from the FlowPM mesh TensorFlow codebase).


def lpt_init_single(lr_field,
                    a0,
                    kvec_lr,
                    halo_size,
                    lr_shape,
                    hr_shape,
                    part_shape,
                    antialias=True,
                    order=1,
                    post_filtering=True,
                    cosmology=Planck15):
  a = a0
  batch_dim = lr_field.shape[0]
  lnc = lr_shape[-1].size

  # Create particles on the high resolution grid
  mstate = mesh_ops.mtf_indices(lr_field.mesh, shape=part_shape, dtype=tf.float32)
  X = mtf.einsum([mtf.ones(lr_field.mesh, [batch_dim]), mstate],
                 output_shape=[batch_dim] + mstate.shape[:])

  # Match the k dimensions to the transposed (y, z, x) FFT layout
  k_dims_lr = [d.shape[0] for d in kvec_lr]
  k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

  grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

  # Reorder the low res FFTs, which were transposed to (y, z, x)
  grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]

  displacement = []
  for f in grad_kfield_lr:
    f = mesh_utils.c2r3d(f, lr_shape[-3:])
    f = mtf.slicewise(
        lambda x: tf.expand_dims(
            tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
        [f],
        output_dtype=tf.float32,
        output_shape=mtf.Shape(hr_shape[0:4] + [
            mtf.Dimension('sx_block', lnc // hr_shape[1].size),
            mtf.Dimension('sy_block', lnc // hr_shape[2].size),
            mtf.Dimension('sz_block', lnc // hr_shape[3].size)
        ]),
        name='my_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    for block_size_dim in hr_shape[-3:]:
      f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
      f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size)
    d = mesh_utils.cic_readout(f, X, halo_size)
    displacement.append(d)
  # Stack the displacement components read out at the particle positions
  displacement = mtf.stack(displacement, "ndim", axis=4)

  pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
  # First order (Zel'dovich) displacement, momentum and force terms
  DX = pt.D1(a) * displacement
  P = (a ** 2 * pt.f1(a) * pt.E(a)) * DX
  F = (a ** 2 * pt.E(a) * pt.gf(a) / pt.D1(a)) * DX
  # TODO: Implement 2nd order LPT

  # Move the particles according to the displacement
  X = X + DX

  return X, P, F
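
For orientation, the displacement computed above is the first order (Zel'dovich) term psi(k) = i k delta(k) / k**2, which apply_gradient_laplace_kernel evaluates in Fourier space before the CIC readout. A minimal single-process NumPy sketch of that kernel, leaving out the mesh decomposition, halo exchange, and growth factors (the function name and plain-FFT layout are illustrative assumptions):

import numpy as np

def lpt_displacement(delta):
  """Sketch: (nc, nc, nc) overdensity -> (3, nc, nc, nc) 1LPT displacement."""
  nc = delta.shape[0]
  k = 2 * np.pi * np.fft.fftfreq(nc)
  kx, ky, kz = np.meshgrid(k, k, k, indexing='ij')
  k2 = kx**2 + ky**2 + kz**2
  k2[0, 0, 0] = 1.0  # avoid 0/0; the zero mode carries no displacement
  delta_k = np.fft.fftn(delta)
  # Gradient (i k) of the inverse Laplacian (1 / k^2), one component per
  # axis, as in apply_gradient_laplace_kernel.
  return np.stack([np.fft.ifftn(1j * ki / k2 * delta_k).real
                   for ki in (kx, ky, kz)])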
Example #2
    def _get_decoder_inputs(self, context):
        """Computes the inputs to the decoder when using transparent attention.

    We must cache on the context in order to ensure that we are not replicating
    variables when the layer's call function is called in different tf variable
    scopes.

    Args:
      context: a Context

    Returns:
      a list containing `self.num_decoder_modules` of tensors with shape
        [<batch_dims>, length_dim, output_vocab_dim]
    """
        if hasattr(context, "decoder_layers_per_module"):
            return context.decoder_layers_per_module

        encoder_layer_outputs = [
            mtf.layers.rename_length_to_memory_length(output)
            for output in context.encoder_layer_outputs
        ]

        layers_per_module = self.layers_per_encoder_module
        encoder_module_outputs_dim = mtf.Dimension(
            "encoder_module_outputs", size=self.encoder_num_modules + 1)
        decoder_module_inputs_dim = mtf.Dimension(
            "decoder_module_inputs", size=self.decoder_num_modules)
        encoder_module_outputs = mtf.stack(
            [encoder_layer_outputs[0]] +
            encoder_layer_outputs[layers_per_module::layers_per_module],
            dim_name="encoder_module_outputs")
        w = mtf.get_variable(
            context.mesh,
            "w",
            mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
            initializer=tf.random_normal_initializer(
                stddev=(encoder_module_outputs_dim.size *
                        decoder_module_inputs_dim.size)**-0.5),
            dtype=context.variable_dtype)
        if context.train and self.dropout_rate != 0.0:
            w = mtf.dropout(w, 1.0 - self.dropout_rate)
        s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
        z = mtf.einsum([s, encoder_module_outputs],
                       reduced_dims=[encoder_module_outputs_dim])
        input_per_decoder = mtf.split(
            z,
            split_dim=decoder_module_inputs_dim,
            num_or_size_splits=decoder_module_inputs_dim.size)
        context.decoder_layers_per_module = [
            mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
        ]
        return context.decoder_layers_per_module
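
The mixing above boils down to a softmax over a learned (encoder outputs x decoder modules) weight matrix, followed by an einsum that reduces the encoder axis. A hedged NumPy sketch of just that step (shapes and names are illustrative, not the mesh_tensorflow API):

import numpy as np

def mix_encoder_outputs(encoder_outputs, w):
  """encoder_outputs: (E, batch, length, d_model); w: (E, D) weights."""
  s = np.exp(w - w.max(axis=0, keepdims=True))
  s /= s.sum(axis=0, keepdims=True)  # softmax over the E encoder outputs
  # Mirrors mtf.einsum([s, encoder_module_outputs]) with the
  # encoder_module_outputs dimension reduced away.
  z = np.einsum('ed,eblm->dblm', s, encoder_outputs)
  return [z[i] for i in range(z.shape[0])]  # one input per decoder module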
Example #3
    def compute_loss(self, decoder: transformer.Unitransformer,
                     hidden: mtf.Tensor, targets: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
        """Returns the loss without computing a softmax over the entire vocab."""
        loss = 0
        tail_cluster_masks = []
        for cluster in self._tail_clusters:
            cluster_mask = cluster.get_cluster_mask(targets)
            tail_cluster_masks.append(cluster_mask)

            if cluster.length_projection_factor == 1:
                targets_in_cluster = mtf.where(cluster_mask, targets, 0)
                hidden_in_cluster = mtf.where(cluster_mask, hidden, 0)
            else:
                # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
                # to reduce the risk of overflowing the projection.
                proj_to_cluster_len = cluster.get_project_to_cluster_length(
                    cluster_mask, dtype=targets.dtype)
                targets_in_cluster = mtf.einsum(
                    [proj_to_cluster_len, targets],
                    reduced_dims=[targets.shape.get_dim_by_name("length")])
                hidden_in_cluster = mtf.einsum(
                    [mtf.cast(proj_to_cluster_len, hidden.dtype), hidden],
                    reduced_dims=[hidden.shape.get_dim_by_name("length")])

            loss += cluster.compute_loss(decoder, hidden_in_cluster,
                                         targets_in_cluster, context)

        tail_clusters_dim = mtf.Dimension("tail_clusters",
                                          len(tail_cluster_masks))
        tail_node_targets = mtf.reduce_sum(
            mtf.stack([(self._head_cluster.end_token_id + i) *
                       mtf.cast(mask, targets.dtype)
                       for i, mask in enumerate(tail_cluster_masks)],
                      tail_clusters_dim.name),
            reduced_dim=tail_clusters_dim)
        head_targets = mtf.where(mtf.cast(tail_node_targets, tf.bool),
                                 tail_node_targets, targets)
        loss += self._head_cluster.compute_loss(decoder, hidden, head_targets,
                                                context)

        return loss
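
The head-target construction at the end folds every tail cluster into a single node id in the head vocabulary. A small NumPy sketch of that remapping (it assumes integer targets and boolean NumPy masks; the function name is made up for illustration):

import numpy as np

def remap_head_targets(targets, cluster_masks, end_token_id):
  """Tail tokens become the per-cluster node id end_token_id + i."""
  tail_node_targets = sum((end_token_id + i) * mask.astype(targets.dtype)
                          for i, mask in enumerate(cluster_masks))
  # The head cluster predicts the node id wherever a tail cluster fired,
  # and the original token everywhere else.
  return np.where(tail_node_targets.astype(bool), tail_node_targets, targets)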
Example #4
def force_single(state,
                 lr_shape,
                 hr_shape,
                 kvec_lr,
                 halo_size,
                 cosmology=Planck15,
                 pm_nc_factor=1,
                 **kwargs):
    """
  Estimate force on the particles given a state.

  Parameters:
  -----------
  state: tensor
    Input state tensor of shape (3, batch_size, npart, 3)

  boxsize: float
    Size of the simulation volume (Mpc/h) TODO: check units

  cosmology: astropy.cosmology
    Cosmology object

  pm_nc_factor: int
    TODO: @modichirag please add doc
  """
    X, P, F = state
    # TODO: support a different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        field = mtf.slice(field, halo_size, block_size_dim.size,
                          block_size_dim.name)

    # Hack using a custom reshape because mesh is pretty dumb
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

    # Reorder the low res FFTs, which were transposed to (y, z, x)
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]

    displacement = []
    for f in kfield_lr:
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size)
        d = mesh_utils.cic_readout(f, X, halo_size)
        displacement.append(d)

    # Stack the force components read out at the particle positions
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
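
Stripped of the mesh decomposition and halo exchange, the step above is the classic particle-mesh force: paint the particles, FFT, apply the gradient-of-inverse-Laplacian kernel, inverse FFT, read the force back at the particle positions, and scale by 3/2 * Om0. A single-resolution NumPy sketch under those assumptions, with nearest-grid-point painting standing in for CIC to keep it short:

import numpy as np

def pm_force(positions, nc, omega_m):
  """positions: (npart, 3) in grid units -> (npart, 3) force (up to units)."""
  # Nearest-grid-point paint (stand-in for cic_paint)
  idx = np.round(positions).astype(int) % nc
  delta = np.zeros((nc, nc, nc))
  np.add.at(delta, (idx[:, 0], idx[:, 1], idx[:, 2]), 1.0)
  delta = delta / delta.mean() - 1.0  # overdensity
  k = 2 * np.pi * np.fft.fftfreq(nc)
  kx, ky, kz = np.meshgrid(k, k, k, indexing='ij')
  k2 = kx**2 + ky**2 + kz**2
  k2[0, 0, 0] = 1.0  # avoid 0/0; the zero mode carries no force
  delta_k = np.fft.fftn(delta)
  force = np.empty((len(positions), 3))
  for a, ki in enumerate((kx, ky, kz)):
    grad = np.fft.ifftn(1j * ki / k2 * delta_k).real
    force[:, a] = grad[idx[:, 0], idx[:, 1], idx[:, 2]]  # NGP readout
  return 1.5 * omega_m * force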
Example #5
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
    """
  Estimate force on the particles given a state.

  Parameters:
  -----------
  state: tensor
    Input state tensor of shape (3, batch_size, npart, 3)

  boxsize: float
    Size of the simulation volume (Mpc/h) TODO: check units

  cosmology: astropy.cosmology
    Cosmology object

  pm_nc_factor: int
    TODO: @modichirag please add doc
  """
    X, P, F = state
    # TODO: support a different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    # Reorder the FFTs, which were transposed to (y, z, x)
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)

    # Split the field into low and high resolution
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])
    hr_field = mtf.reshape(high, high.shape[:-1])
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)

    # Hack using a custom reshape because mesh is pretty dumb
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield,
                                                    kvec_lr,
                                                    r_split=0)
    # Chain the kernels on the low res branch, as done for the high res one
    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(kfield_lr, kvec_lr)
    kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield,
                                                    kvec_hr,
                                                    r_split=0)
    kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

    # Reorder the low res FFTs, which were transposed to (y, z, x)
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
    kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

    displacement = []
    for f, g in zip(kfield_lr, kfield_hr):
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Stack the force components read out at the particle positions
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
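
The low/high split above is a band decomposition: the low branch is a downsampled copy of the field, and the high branch keeps only the residual that a downsample/upsample round trip loses, so the two bands sum back to the full field by construction. A hedged 1-D NumPy sketch of the idea (plain averaging and replication stand in for the antialiased filters used above):

import numpy as np

def band_split(field, factor=2):
  low = field.reshape(-1, factor).mean(axis=1)  # downsample by averaging
  low_up = np.repeat(low, factor)               # upsample by replication
  high = field - low_up                         # high-frequency residual
  return low, high

x = np.random.randn(16)
low, high = band_split(x)
assert np.allclose(np.repeat(low, 2) + high, x)  # bands sum to the field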