示例#1
0
  def update_context(self, context, x, pool_dim_name):
    """Refresh `context` so that it describes the pooled tensor `x`.

    `context.position`, `context.length_dim` and `context.sequence_id` are
    overwritten in place; nothing is returned.

    Args:
      context: object whose `position`, `sequence_id` and `length_dim`
        attributes are read and then replaced.
      x: tensor whose shape has a dimension named `pool_dim_name`; the size of
        that dimension is used as the pooled sequence length.
      pool_dim_name: name of the pooled (length) dimension.
    """
    pooled_size = x.shape.get_dim_by_name(pool_dim_name).size

    # Positions are taken from the front of the existing position tensor
    # rather than strided, so that e.g. the 3rd pre-pooling position becomes
    # the 2nd position after pooling instead of keeping its old index.
    context.position = mtf.slice(
        context.position,
        begin=0,
        size=pooled_size,
        slice_dim_name=pool_dim_name)

    # Must be built while `context.length_dim` still refers to the
    # pre-pooling dimension, i.e. before it is replaced below.
    pooled_sequence_id = mtf.stride_tensor_1d(
        context.sequence_id,
        pool_dim=context.length_dim,
        pool_size=self.pooling_size)

    context.length_dim = mtf.Dimension(name=pool_dim_name, size=pooled_size)
    context.sequence_id = pooled_sequence_id
示例#2
0
def halo_reduce(x, blocks_dim, block_size_dim, halo_size, wrap=True):
    """Sum the overlapping margins of adjacent blocks into each block.

    Along `block_size_dim` every block is cut into a left margin, a center and
    a right margin, the margins being `2 * halo_size` wide. The right margin
    of the previous block (obtained with `mtf.shift` by +1 along `blocks_dim`)
    is added to the left margin, the left margin of the next block (shift by
    -1) is added to the right margin, and the three pieces are padded back to
    full width and summed.

    Args:
      x: a Tensor.
      blocks_dim: a Dimension in x.shape indexing the blocks.
      block_size_dim: a Dimension in x.shape indexing positions in a block.
      halo_size: an integer; each margin is `2 * halo_size` wide.
      wrap: a boolean, forwarded to `mtf.shift`.

    Returns:
      a Tensor with the same shape as x (x itself when `halo_size` is 0).

    Raises:
      AssertionError: if the two margins do not fit side by side in a block,
        i.e. `4 * halo_size > block_size_dim.size`.
    """
    if halo_size == 0:
        return x
    block_size = block_size_dim.size
    margin = 2 * halo_size
    # The center piece has size block_size - 2 * margin, so both margins must
    # fit in the block without overlapping; otherwise the slice below would
    # be asked for a negative size.
    assert 2 * margin <= block_size, (
        "halo_reduce needs 4 * halo_size <= block size, got halo_size={} "
        "and block size={}".format(halo_size, block_size))

    dim_name = block_size_dim.name
    left_margin = mtf.slice(x, 0, margin, dim_name)
    right_margin = mtf.slice(x, block_size - margin, margin, dim_name)
    center = mtf.slice(x, margin, block_size - 2 * margin, dim_name)

    # Halo exchange: add the facing margin of the neighbouring block.
    left = mtf.shift(right_margin, 1, blocks_dim, wrap) + left_margin
    right = mtf.shift(left_margin, -1, blocks_dim, wrap) + right_margin

    # Recompose the block by padding every piece back to the full width.
    left = mtf.pad(left, [0, block_size - margin], dim_name)
    right = mtf.pad(right, [block_size - margin, 0], dim_name)
    center = mtf.pad(center, [margin, margin], dim_name)
    x = left + center + right
    return x
示例#3
0
    def call(self, context, x, losses=None):
        """Apply the layer to `x` and update `context` for the new length.

        When `self.canine_mode` is set the work is delegated to
        `call_canine_encoder`. Otherwise `x` optionally goes through a 1d
        convolution and an RMS norm, is passed to
        `custom_attention.gradient_based_subword_tokenization`, and the
        context's `length_dim`, `position` and `sequence_id` are replaced to
        match the "length" dimension of the result.

        Args:
          context: object carrying `length_dim`, `position`, `sequence_id`
            and a `get_position()` method; mutated in place.
          x: input tensor with a dimension named "length".
          losses: only forwarded to `call_canine_encoder`.

        Returns:
          A pair `(output, context)`.
        """
        if self.canine_mode:
            # Canine-like ByT5 + LASC baseline from the paper.
            return self.call_canine_encoder(context, x, losses=losses)

        if self.conv_type == "conv1d":
            tf.logging.info("Using 1d conv")
            channel_dim = x.shape[-1]
            scratch_dim = mtf.Dimension("tmp_dim", channel_dim.size)
            x = mtf.layers.conv1d(
                x, scratch_dim, filter_size=self.filter_size, stride=1)
            # Give the convolution output its original channel name back.
            x = mtf.rename_dimension(x, "tmp_dim", channel_dim.name)
            tf.logging.info(x)

        if self.norm:
            x = sublayer_rms_norm(x, None, context)

        pooled = custom_attention.gradient_based_subword_tokenization(
            x,
            x.shape.get_dim_by_name("length"),
            downsample=self.downsample_query,
            use_offsets=self.use_offsets,
            consider_chars_as_blocks=self.consider_chars_as_blocks,
            use_block_pos_embedding=self.use_block_pos_embedding,
            memory_embeddings=self.num_memory_slots,
            context=context,
            block_mixing_mode=self.block_mixing_mode,
            activation=self.rank_activation,
            downsample_function=self.gbst_pool)

        # Re-point the context at the (possibly shorter) length dimension.
        pooled_length_dim = pooled.shape.get_dim_by_name("length")
        context.length_dim = pooled_length_dim
        context.position = context.get_position()
        context.sequence_id = mtf.slice(
            context.sequence_id,
            begin=0,
            size=pooled_length_dim.size,
            slice_dim_name=pooled_length_dim.name)

        if self.use_ffn:
            # Not actually used in Charformer.
            tf.logging.info("Using FFN")
            pooled = pooled + self.ffn.call(context, pooled)
        if self.norm:
            pooled = sublayer_rms_norm(pooled, None, context)
        return pooled, context
示例#4
0
 def call_canine_encoder(self, context, x, losses=None):
     """Run the Canine baseline encoder (byte-level T5 + LASC in the paper).

     Applies `custom_attention.local_attention_1d` to projections of `x`,
     then a strided 1d convolution whose stride is
     `int(self.downsample_query)`, and finally replaces the context's
     `length_dim`, `position` and `sequence_id` to match the "length"
     dimension of the convolution output.

     Args:
       context: object carrying `length_dim`, `position`, `sequence_id`,
         `write_priority`, `read_priority` and `get_position()`; mutated in
         place.
       x: input tensor.
       losses: accepted for interface compatibility; not used here.

     Returns:
       A pair `(output, context)`.
     """
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         k = v = params.compute_kv(x)
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)

     # Local attention; the input shape is kept for the output projection.
     input_shape = x.shape
     attended = custom_attention.local_attention_1d(
         q,
         k,
         v,
         length_dim=context.length_dim,
         length_dim_num_splits=1,
         key_dim=self.kv_dim,
         value_dim=self.kv_dim,
         fully_autoregressive=False,
         radius=self.radius,
         sequence_id=context.sequence_id,
         write_priority=context.write_priority,
         read_priority=context.read_priority,
         context=context,
         attention_kwargs=self.attention_kwargs_from_context(context))
     o = params.compute_output(attended, output_shape=input_shape)

     # Strided convolution; `downsample_query` is reused as the stride "r".
     scratch_dim = mtf.Dimension("tmp_dim", o.shape[-1].size)
     o = mtf.layers.conv1d(
         o,
         scratch_dim,
         filter_size=self.filter_size,
         stride=int(self.downsample_query))
     o = mtf.rename_dimension(o, "tmp_dim", "d_model")
     tf.logging.info(o)

     # Re-point the context at the downsampled length dimension.
     strided_length_dim = o.shape.get_dim_by_name("length")
     context.length_dim = strided_length_dim
     context.position = context.get_position()
     context.sequence_id = mtf.slice(
         context.sequence_id,
         begin=0,
         size=strided_length_dim.size,
         slice_dim_name=strided_length_dim.name)
     return o, context
示例#5
0
def downsample_hr_to_lr(field, lr_shape, hr_shape, downsampling_factor, halo_size, splittables, mesh):
    """Downsample a block-decomposed field and reshape it to `lr_shape`.

    Steps performed:
    - append a size-1 'h_dim' dimension, call `mesh_utils.downsample` with
      antialiasing, and drop the trailing dimension again
    - slice each of the last three dimensions of `hr_shape`, starting at
      `halo_size // 2**downsampling_factor` and keeping
      `dim.size // 2**downsampling_factor` entries
    - index `[:, 0, 0, 0]` slice-wise to obtain `lr_shape`

    `mesh` is accepted but not used by this function.
    """
    # Append a size-1 trailing dimension for the downsampling call.
    expanded = mtf.reshape(field, field.shape+[mtf.Dimension('h_dim', 1)])
    low = mesh_utils.downsample(expanded, downsampling_factor, antialias=True)
    low = mtf.reshape(low, low.shape[:-1])

    # Halo offset and block sizes are both expressed in downsampled cells.
    scale = 2**downsampling_factor
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size//scale, block_size_dim.size//scale, block_size_dim.name)

    def _drop_block_axes(t):
        return t[:, 0, 0, 0]

    # Hack: a slice-wise indexing op is used in place of a regular reshape.
    low = mtf.slicewise(_drop_block_axes,
                        [low],
                        output_dtype=expanded.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=splittables)

    return low
示例#6
0
def cic_paint_fr(field, state, output_shape, hr_shape, halo_size, splittables, mesh, weights=None):
    '''Paint the positions in `state[0]` onto `field` using the block layout.

    Ops performed, in order:
    - slice-wise reshape of `field` to the first four dims of `hr_shape`
      followed by per-block 'sx_block' / 'sy_block' / 'sz_block' dims
    - pad the last three dims of `hr_shape` by `halo_size` on both sides
    - `mesh_utils.cic_paint`
    - `mesh_ops.halo_reduce` along each blocks/block-size dimension pair
    - slice to remove the padding
    - slice-wise reshape to `output_shape`

    `mesh` is accepted but not used by this function.
    '''
    lnc = field.shape[-1].size

    def _insert_block_axes(t):
        # Three new size-1 axes right after the leading axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    def _drop_block_axes(t):
        return t[:,0,0,0]

    block_size_dims = [
        mtf.Dimension('sx_block', lnc//hr_shape[1].size),
        mtf.Dimension('sy_block', lnc//hr_shape[2].size),
        mtf.Dimension('sz_block', lnc//hr_shape[3].size),
    ]
    field = mtf.slicewise(_insert_block_axes,
                          [field],
                          output_dtype=tf.float32,
                          output_shape=mtf.Shape(hr_shape[0:4]+block_size_dims),
                          name='my_reshape',
                          splittable_dims=splittables)

    # Make room for the halo on both sides of every block.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    field = mesh_utils.cic_paint(field, state[0], halo_size, weights)

    # Halo exchange; the padded block dims are read from the field itself.
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # Remove the borders added by the padding above.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.slice(field, halo_size, block_size_dim.size, block_size_dim.name)

    field = mtf.slicewise(_drop_block_axes,
                          [field],
                          output_dtype=field.dtype,
                          output_shape=output_shape,
                          name='my_dumb_reshape',
                          splittable_dims=splittables)
    return field
def _crop_synthetic_logits(r, q, memory_length_dim, context, gather_position):
  """Crop synthesized logits `r` to the current memory and query lengths.

  Args:
    r: a Tensor with a dimension named `memory_length_dim.name`.
    q: the query Tensor; outside incremental mode the size of its "length"
      dimension is what `r` is cropped to along "length".
    memory_length_dim: a Dimension; `r` is sliced to its size.
    context: object providing `mode` and, when gathering, `position`.
    gather_position: in incremental mode, whether to gather `r` at
      `context.position` along "length" (otherwise `r` is left as is).

  Returns:
    the cropped Tensor.
  """
  r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
  if context.mode == "incremental":
    if gather_position:
      r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
  else:
    length_dim = q.shape.get_dim_by_name("length")
    r = mtf.slice(r, 0, length_dim.size, length_dim.name)
  return r


def _mix_synthetic_logits(logits, r, context, use_alpha):
  """Combine dot-product `logits` with synthesized logits `r`.

  With `use_alpha` the two are blended as (1 - a) * logits + a * r, where a
  is the sigmoid of a variable named "alpha"; otherwise they are added.
  """
  if not use_alpha:
    return logits + r
  alpha = mtf.get_variable(context.mesh,
                           "alpha",
                           mtf.Shape([mtf.Dimension("alpha", 1)]),
                           initializer=tf.zeros_initializer(),
                           dtype=context.variable_dtype)
  alpha = mtf.sigmoid(alpha)
  return ((1-alpha) * logits) + (alpha * r)


def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention or not. When False the
      logits are the plain dot product of q and k.
    synthesize_mode: which variant of synthesizer to use; one of "random",
      "factorized", "dense_minus", "random_plus", "random_plus_alpha",
      "dense_plus", "dense_plus_alpha".
    factorized_dim: factorized dim for synthesizers
    max_length: max length of input sequence. Only the "random" and
      "dense_minus" modes read it; the other modes use a fixed 512.
    context: context since we need context mode. Its `train` attribute is
      always read (for dropout); the synthesizer modes also read `mesh`,
      `variable_dtype`, `mode` and `position`, so it must not be None.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim, except for the
    synthesizer modes without "plus" in their name, where the shape is
    q.shape.dims[:-1] followed by [value_dim] ("dense_minus") or by
    [heads, value_dim] ("random", "factorized").

  Raises:
    ValueError: if `synthesize` is set and `synthesize_mode` is not one of
      the supported modes.
  """
  known_modes = ("random", "factorized", "dense_minus",
                 "random_plus", "random_plus_alpha",
                 "dense_plus", "dense_plus_alpha")
  # Fail early: an unrecognised mode would otherwise leave `logits` unset.
  if synthesize and synthesize_mode not in known_modes:
    raise ValueError(
        "Unknown synthesize_mode {!r}; expected one of {}".format(
            synthesize_mode, ", ".join(known_modes)))

  if not synthesize:
    # Regular dot-product attention logits.
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
  else:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      logits = _crop_synthetic_logits(
          r, q, memory_length_dim, context, gather_position=True)
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      rank = factorized_dim
      # NOTE(review): 512 is hard-coded here rather than `max_length`, and
      # both factors share the shape [tmp, heads, memory_length] (neither has
      # a "length" dimension) -- confirm this is intended.
      r1_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r2_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", 512)])
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], r_shape)
      logits = _crop_synthetic_logits(
          r, q, memory_length_dim, context, gather_position=True)
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = _crop_synthetic_logits(
          logits, q, memory_length_dim, context, gather_position=False)
    elif synthesize_mode in ("random_plus_alpha", "random_plus"):
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      # NOTE(review): 512 is hard-coded here rather than `max_length`.
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = _crop_synthetic_logits(
          r, q, memory_length_dim, context, gather_position=True)
      logits = _mix_synthetic_logits(
          logits, r, context, use_alpha="alpha" in synthesize_mode)
    else:
      # "dense_plus_alpha" or "dense_plus":
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      # NOTE(review): 512 is hard-coded here rather than `max_length`.
      tmp_dim = mtf.Dimension("memory_length", 512)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = _crop_synthetic_logits(
          r, q, memory_length_dim, context, gather_position=False)
      logits = _mix_synthetic_logits(
          logits, r, context, use_alpha="alpha" in synthesize_mode)

  if bias is not None:
    logits += bias

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim

  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
示例#8
0
def nbody_prototype(mesh,
                    infield=False,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of a function computing the LPT displacement / N-body evolution.

    Builds the initial conditions (imported from a placeholder when `infield`
    is set, otherwise drawn with `mtfpm.linear_field`), evolves them with
    `mtfpm.nbody_single` when `FLAGS.nbody` is set (otherwise only
    `mtfpm.lpt_init_single` at the last stage), and paints the final state
    onto a field using the block decomposition given by `FLAGS.nx`,
    `FLAGS.ny` and `FLAGS.hsize`.

    Note that the default argument values are read from `FLAGS` when the
    function is defined, not when it is called.

    Args:
      mesh: mesh on which all mesh-tensorflow tensors are created.
      infield: if True, initial conditions come from the returned placeholder.
      nc: number of cells per side.
      bs: box size, forwarded to `mtfpm.linear_field`.
      batch_size: size of the batch dimension.
      a0: first entry of the `stages` array.
      a: last entry of the `stages` array.
      nsteps: number of entries in `stages`.
      dtype: dtype of the placeholder and of the final field.

    Returns:
      A tuple `(initc, final_field, input_field)`: the initial-condition
      tensor, the painted final field, and the TensorFlow placeholder.
    """
    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is used below, so the file is
    # parsed once and the remaining columns are ignored.
    # NOTE(review): `plin` keeps the dtype returned by np.loadtxt, while the
    # Fourier kernels below are cast to float32 -- confirm this is intended.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin, shape=[pk_dim])

    def _import_k(component, dim):
        # One Fourier kernel component as a float32 vector along `dim`.
        return mtf.import_tf_tensor(mesh,
                                    component.squeeze().astype('float32'),
                                    shape=[dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = _import_k(kvec[0], tfx_dim)
    ky = _import_k(kvec[1], tfy_dim)
    kz = _import_k(kvec[2], tfz_dim)
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = _import_k(kvec_lr[0], tx_dim)
    ky_lr = _import_k(kvec_lr[1], ty_dim)
    kz_lr = _import_k(kvec_lr[2], tz_dim)
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    ## Compute initial conditions, distributed
    input_field = tf.placeholder(dtype, [batch_size, nc, nc, nc])
    if infield:
        initc = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    else:
        initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    # The LPT initialisation is evaluated at a0 when an N-body run follows,
    # and directly at the final stage otherwise.
    run_nbody = FLAGS.nbody
    final_state = mtfpm.lpt_init_single(
        initc,
        a0 if run_nbody else stages[-1],
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    if run_nbody:
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(final_state, stages, lr_shape,
                                         hr_shape, kv_lr, halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field, input_field
示例#9
0
文件: unet.py 项目: bruinxiong/mesh-1
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Build the UNet graph together with its loss and evaluation metrics.

    Args:
      mesh: a MeshTensorflow.mesh object.
      mesh_impl: a mesh implementation, such as SimdMeshImpl and
        PlacementMeshImpl.
      dataset_str: either 'train' or 'eval'; selects the batch size and the
        training flag handed to the convolution blocks.
      images: a laid out Tensor with shape [batch, x, y, num_channels]
        or [batch, x, y, z, num_channels].
      labels: a laid out Tensor with shape [batch, x, y, num_classes]
        or [batch, x, y, z, num_classes].

    Returns:
      A tuple `(preds, loss, eval_metrics, all_bn_update_ops)`:
        preds: list with the downsampled class-1 and class-2 probabilities,
          optionally the downsampled class-2 ground truth (when
          `FLAGS.output_ground_truth`), then the per-batch intersection and
          area sum used for the dice score.
        loss: weighted sum of the dice loss and the cross-entropy loss.
        eval_metrics: dict mapping metric names to tensors.
        all_bn_update_ops: update ops collected from all convolution blocks.
    """
    is_training = (dataset_str == 'train')
    if is_training:
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)

    use_2d = FLAGS.sampled_2d_slices
    res = FLAGS.ct_resolution
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 res // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 res // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', res)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    # Size-1 class dimension left behind after slicing out a single class.
    one_class_dim = mtf.Dimension('label_c', 1)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import the input features, then the ground truth labels.
    x = mtf.cast(
        mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape),
        mtf_dtype)
    t = mtf.cast(
        mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape),
        mtf_dtype)

    # Transpose the blocks; the z dimension only exists for 3D volumes.
    leading_dims = [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim
    ]
    if not use_2d:
        leading_dims.append(image_sz_dim)
    x = mtf.transpose(x, leading_dims + [image_c_dim])
    t = mtf.transpose(t, leading_dims + [label_c_dim])

    # Network.
    levels = []
    all_bn_update_ops = []
    last_conv = FLAGS.n_conv_per_block - 1

    # Contracting path: convolution blocks separated by max pooling.
    for depth in range(FLAGS.network_depth):
        n_filters = FLAGS.n_base_filters * (2**depth)
        for n_conv in range(FLAGS.n_conv_per_block):
            # No dropout in the very first layer.
            is_first_layer = depth == 0 and n_conv == 0
            keep_p = 1.0 if is_first_layer else FLAGS.dropout_keep_p
            x, bn_update_ops = conv_with_spatial_partition(
                x, use_2d, image_nx_dim, image_ny_dim, n_filters, keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        if depth < FLAGS.network_depth - 1:
            if use_2d:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # Expanding path: up-convolution, skip connection, convolution blocks.
    for depth in reversed(range(FLAGS.network_depth - 1)):
        n_filters = FLAGS.n_base_filters * (2**depth)
        skip_dim_name = 'conv_{}_{}'.format(depth, last_conv)
        x = deconv_with_spatial_partition(
            x, use_2d, image_nx_dim, image_ny_dim, n_filters,
            FLAGS.dropout_keep_p, skip_dim_name, variable_dtype,
            'deconv_{}_0'.format(depth))
        x = mtf.concat([x, levels[depth]], concat_dim_name=skip_dim_name)

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, use_2d, image_nx_dim, image_ny_dim, n_filters,
                FLAGS.dropout_keep_p, FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # Final 1x1 convolution to class logits; no dropout here.
    final_conv_name = 'final_conv_{}'.format(FLAGS.label_c)
    if use_2d:
        y = mtf.layers.conv2d_with_blocks(
            x,
            label_c_dim,
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name=final_conv_name,
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            label_c_dim,
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name=final_conv_name,
        )

    def const(value, dtype):
        # mtf.constant keeps the constant off the CPU side.
        return mtf.constant(mesh, value, shape=[], dtype=dtype)

    def sum_per_case(tensor):
        # Reduce everything except the batch dimension.
        return mtf.reduce_sum(tensor, output_shape=mtf.Shape([batch_dim]))

    def as_mask(condition):
        return mtf.cast(condition, mtf_dtype)

    # Class 1 is the liver, class 2 the lesion.
    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = as_mask(mtf.equal(argmax_t, const(1, tf.int32)))
    lesion_t = as_mask(mtf.equal(argmax_t, const(2, tf.int32)))

    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = as_mask(mtf.equal(argmax_y, const(2, tf.int32)))

    # Summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # Summary of accuracy.
    accuracy = mtf.reduce_mean(as_mask(mtf.equal(argmax_y, argmax_t)))

    # Cross-entropy loss, with the liver and lesion regions up-weighted.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    pixel_weight = const(1, mtf_dtype)
    pixel_weight += liver_t * const(FLAGS.xen_liver_weight - 1, mtf_dtype)
    pixel_weight += lesion_t * const(
        FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight, mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss on the soft lesion probability.
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=one_class_dim)
    prob_intersect = sum_per_case(lesion_prob * lesion_t)
    prob_area_sum = sum_per_case(lesion_prob + lesion_t)

    per_case_numerator = const(-2, mtf_dtype) * prob_intersect
    per_case_denominator = prob_area_sum + const(FLAGS.dice_epsilon,
                                                 mtf_dtype)
    loss_dice_per_case = mtf.reduce_mean(
        per_case_numerator / per_case_denominator)

    global_numerator = const(-2, mtf_dtype) * mtf.reduce_sum(prob_intersect)
    global_denominator = mtf.reduce_sum(prob_area_sum) + const(
        FLAGS.dice_epsilon, mtf_dtype)
    loss_dice_global = global_numerator / global_denominator

    loss_dice = (loss_dice_per_case + loss_dice_global) * const(
        0.5, mtf_dtype)

    loss = const(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + const(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    # Summary of dice on the hard (argmax) predictions.
    intersect = sum_per_case(lesion_y * lesion_t)
    area_sum = sum_per_case(lesion_y + lesion_t)
    dice_per_case = mtf.reduce_mean(
        const(2, mtf_dtype) * intersect /
        (area_sum + const(0.000001, mtf_dtype)))
    dice_global = const(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + const(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    # Downsample the probabilities (and optionally the lesion ground truth).
    if use_2d:
        avg_pool = mtf.layers.avg_pool2d
        pool_ksize = (FLAGS.pred_downsample, ) * 2
    else:
        avg_pool = mtf.layers.avg_pool3d
        pool_ksize = (FLAGS.pred_downsample, ) * 3
    y_prob_downsampled = avg_pool(y_prob, ksize=pool_ksize)
    if FLAGS.output_ground_truth:
        lesion_gt_downsampled = avg_pool(mtf.slice(t, 2, 1, 'label_c'),
                                         ksize=pool_ksize)

    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled, reduced_dim=one_class_dim),
        mtf.reduce_sum(lesion_prob_downsampled, reduced_dim=one_class_dim)
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled, reduced_dim=one_class_dim))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
示例#10
0
def nbody_fn(mesh,
             klin,
             plin,
             nc=FLAGS.nc,
             bs=FLAGS.box_size,
             batch_size=FLAGS.batch_size,
             a0=FLAGS.a0,
             a=FLAGS.af,
             nsteps=FLAGS.nsteps,
             dtype=tf.float32):
  """Build the two-level ("pyramid") N-body graph on `mesh`.

  A linear field is created on the full grid, re-expressed as halo-padded
  high-resolution blocks plus a downsampled low-resolution grid, initialised
  with `mtfpm.lpt_init`, evolved with `mtfpm.nbody`, and finally painted back
  onto the full grid.

  Args:
    mesh: mesh on which every tensor is created.
    klin: unused; kept so that existing callers keep working.
    plin: array of power-spectrum values (must support `len` and `.astype`).
    nc: number of cells per side of the full-resolution grid.
    bs: box size forwarded to `mtfpm.linear_field`.
    batch_size: size of the batch dimension.
    a0: first value of the integration `stages`; also the value handed to
      `mtfpm.lpt_init` so that the initial state matches the first stage.
    a: last value of the integration `stages`.
    nsteps: number of integration stages between `a0` and `a` (inclusive).
    dtype: unused; tensors are created as float32.

  Returns:
    A tuple `(initc, final_field)`: the initial linear field and the painted
    final field, both with dimensions [batch, nx, ny, nz].
  """
  stages = np.linspace(a0, a, nsteps, endpoint=True)

  # Parameters of the small-scale (block) decomposition.
  n_block_x = FLAGS.nx
  n_block_y = FLAGS.ny
  n_block_z = 1
  halo_size = FLAGS.hsize

  # Parameters of the large-scale decomposition: the low-resolution grid has
  # nc / 2**downsampling_factor cells per side.
  downsampling_factor = FLAGS.dsample
  lnc = nc // 2**downsampling_factor

  # Full-resolution grid dimensions.
  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  # Dimensions carrying the full-resolution Fourier kernels.
  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions of the low resolution grid
  x_dim = mtf.Dimension("nx_lr", lnc)
  y_dim = mtf.Dimension("ny_lr", lnc)
  z_dim = mtf.Dimension("nz_lr", lnc)

  tx_dim = mtf.Dimension("tx_lr", lnc)
  ty_dim = mtf.Dimension("ty_lr", lnc)
  tz_dim = mtf.Dimension("tz_lr", lnc)

  # Number of blocks along each axis ...
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  # ... and the number of cells inside each block.
  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(
      mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
  ky = mtf.import_tf_tensor(
      mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
  kz = mtf.import_tf_tensor(
      mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
  # The kernels are handed over in (y, z, x) order.
  kv = [ky, kz, kx]

  # kvec for low resolution grid, rescaled by the downsampling ratio.
  kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)

  kx_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[0].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[1].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(
      mesh,
      kvec_lr[2].squeeze().astype('float32') / 2**downsampling_factor,
      shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  # kvec for high resolution blocks: each block is enlarged by a halo of
  # `halo_size` cells on both sides.
  padded_sx_dim = mtf.Dimension('padded_sx_block',
                                nc // n_block_x + 2 * halo_size)
  padded_sy_dim = mtf.Dimension('padded_sy_block',
                                nc // n_block_y + 2 * halo_size)
  padded_sz_dim = mtf.Dimension('padded_sz_block',
                                nc // n_block_z + 2 * halo_size)
  kvec_hr = flowpm.kernels.fftk([
      nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
      nc // n_block_z + 2 * halo_size
  ],
                                symmetric=False)

  kx_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[0].squeeze().astype('float32'), shape=[padded_sx_dim])
  ky_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[1].squeeze().astype('float32'), shape=[padded_sy_dim])
  kz_hr = mtf.import_tf_tensor(
      mesh, kvec_hr[2].squeeze().astype('float32'), shape=[padded_sz_dim])
  kv_hr = [ky_hr, kz_hr, kx_hr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, x_dim, y_dim, z_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  # Compute the distributed initial conditions.
  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

  # Reshape the field into the blocked high-resolution layout. Each local
  # slice gets three singleton axes inserted after the batch axis, which play
  # the role of the three block-count dimensions.
  field = mtf.slicewise(
      lambda x: tf.expand_dims(
          tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1), [initc],
      output_dtype=tf.float32,
      output_shape=hr_shape,
      name='my_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

  # Add the halo cells around every block ...
  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

  # ... and exchange them with the neighbouring blocks.
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

  # A trailing singleton dimension is added for the downsampling call and
  # dropped again right after it.
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  high = field
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

  low = mtf.reshape(low, low.shape[:-1])
  high = mtf.reshape(high, high.shape[:-1])

  # Strip the (downsampled) halo from the low-resolution blocks.
  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack: use a custom slicewise op instead of mtf.reshape to collapse the
  # block dimensions.
  low = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  # Start the LPT state at `a0`, i.e. at the first integration stage. This
  # used to be a hard-coded 0.1, which disagreed with `stages` for any other
  # value of `a0`.
  state = mtfpm.lpt_init(
      low,
      high,
      a0,
      kv_lr,
      kv_hr,
      halo_size,
      hr_shape,
      lr_shape,
      part_shape[1:],
      downsampling_factor=downsampling_factor,
      antialias=True,
  )

  final_state = mtfpm.nbody(
      state,
      stages,
      lr_shape,
      hr_shape,
      kv_lr,
      kv_hr,
      halo_size,
      downsampling_factor=downsampling_factor)

  # Paint the first entry of the final state onto an empty halo-padded grid.
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Collapse the block dimensions back to the full-resolution layout.
  final_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [final_field],
      output_dtype=tf.float32,
      output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
      name='my_dumb_reshape',
      splittable_dims=part_shape[:-1] + hr_shape[:4])

  return initc, final_field
示例#11
0
def benchmark_model(mesh):
  """Build a benchmark graph: linear field, single-mesh LPT, then painting.

  A linear field is created on `mesh` with `mtfpm.linear_field`, an LPT state
  is computed with `mtfpm.lpt_init_single`, the first entry of that state is
  painted onto a halo-padded blocked grid, and the painted field is summed.
  No N-body integration is performed.

  Args:
    mesh: mesh on which every tensor is created.

  Returns:
    The `mtf.reduce_sum` of the painted field.
  """
  # Setup parameters
  bs = FLAGS.box_size
  nc = FLAGS.cube_size
  batch_size = FLAGS.batch_size
  a0 = FLAGS.a0

  # Only the second column of the table (the power-spectrum values) is used
  # below, so the file is parsed a single time.
  plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

  # Define the named dimensions
  # Parameters of the small scales decomposition
  n_block_x = 8
  n_block_y = 4
  n_block_z = 1
  halo_size = 4

  # Full-resolution grid dimensions.
  fx_dim = mtf.Dimension("nx", nc)
  fy_dim = mtf.Dimension("ny", nc)
  fz_dim = mtf.Dimension("nz", nc)

  # Dimensions carrying the Fourier kernels of the linear field.
  tfx_dim = mtf.Dimension("tx", nc)
  tfy_dim = mtf.Dimension("ty", nc)
  tfz_dim = mtf.Dimension("tz", nc)

  # Dimensions carrying the kernels handed to the LPT step; they have the
  # full size `nc` here because no downsampling is applied.
  tx_dim = mtf.Dimension("tx_lr", nc)
  ty_dim = mtf.Dimension("ty_lr", nc)
  tz_dim = mtf.Dimension("tz_lr", nc)

  # Number of blocks along each axis and number of cells per block.
  nx_dim = mtf.Dimension('nx_block', n_block_x)
  ny_dim = mtf.Dimension('ny_block', n_block_y)
  nz_dim = mtf.Dimension('nz_block', n_block_z)

  sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
  sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
  sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

  batch_dim = mtf.Dimension("batch", batch_size)
  pk_dim = mtf.Dimension("npk", len(plin))
  pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

  # Compute necessary Fourier kernels
  kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
  kx = mtf.import_tf_tensor(mesh,
                            kvec[0].squeeze().astype('float32'),
                            shape=[tfx_dim])
  ky = mtf.import_tf_tensor(mesh,
                            kvec[1].squeeze().astype('float32'),
                            shape=[tfy_dim])
  kz = mtf.import_tf_tensor(mesh,
                            kvec[2].squeeze().astype('float32'),
                            shape=[tfz_dim])
  # The kernels are handed over in (y, z, x) order.
  kv = [ky, kz, kx]

  kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
  kx_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[0].squeeze().astype('float32'),
                               shape=[tx_dim])
  ky_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[1].squeeze().astype('float32'),
                               shape=[ty_dim])
  kz_lr = mtf.import_tf_tensor(mesh,
                               kvec_lr[2].squeeze().astype('float32'),
                               shape=[tz_dim])
  kv_lr = [ky_lr, kz_lr, kx_lr]

  shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
  hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
  part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

  initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)
  state = mtfpm.lpt_init_single(
      initc,
      a0,
      kv_lr,
      halo_size,
      lr_shape,
      hr_shape,
      part_shape[1:],
      antialias=True,
  )

  # The N-body integration is intentionally not run in this benchmark; the
  # LPT state is painted directly.
  final_state = state

  # paint the field
  final_field = mtf.zeros(mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.pad(final_field, [halo_size, halo_size],
                          block_size_dim.name)
  final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
  # Halo exchange
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]):
    final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                  halo_size)
  # Remove borders
  for block_size_dim in hr_shape[-3:]:
    final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                            block_size_dim.name)

  # Hack: use a custom slicewise op instead of mtf.reshape to collapse the
  # block dimensions back to the full-resolution layout.
  final_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [final_field],
                              output_dtype=tf.float32,
                              output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                              name='my_dumb_reshape',
                              splittable_dims=part_shape[:-1] + hr_shape[:4])

  return mtf.reduce_sum(final_field)
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """
    Prototype of function computing LPT displacement on a two-level grid.

    The TensorFlow initial conditions are rearranged into halo-padded
    high-resolution blocks plus a downsampled low-resolution grid, an LPT
    state is computed with `mtfpm.lpt_init`, and the first entry of that
    state is painted back onto the full grid. No N-body integration is run.

    Args:
      mesh: mesh on which every tensor is created.
      initial_conditions: TensorFlow tensor that can be reshaped to
        [batch_size, nc, nc, nc].
      derivs: unused; kept so that existing callers keep working.
      nc: number of cells per side of the full-resolution grid.
      bs: unused; kept so that existing callers keep working.
      batch_size: size of the batch dimension.
      a0: value handed to `mtfpm.lpt_init`.
      a: unused while the N-body integration is disabled.
      nsteps: unused while the N-body integration is disabled.

    Returns:
      The painted field as a mesh tensor with dimensions [batch, nx, ny, nz].
    """
    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Parameters of the large scales decomposition: the low-resolution grid
    # has nc / 2**downsampling_factor cells per side.
    downsampling_factor = FLAGS.dsample
    lnc = nc // 2**downsampling_factor

    # Full-resolution grid dimensions.
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    # Number of blocks along each axis and number of cells per block.
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # kvec for low resolution grid, rescaled by the downsampling ratio.
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False)

    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32') /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks: each block is enlarged by a halo of
    # `halo_size` cells on both sides.
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)
    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False)

    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype('float32'),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype('float32'),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype('float32'),
                                 shape=[padded_sz_dim])
    # NOTE(review): the high-resolution kernels are listed in (x, y, z) order
    # while `kv_lr` above uses (y, z, x); confirm which order
    # `mtfpm.lpt_init` expects before relying on this.
    kv_hr = [kx_hr, ky_hr, kz_hr]

    lr_shape = [batch_dim, x_dim, y_dim, z_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Split every spatial axis into (block index, position inside block) and
    # move the three block indices in front of the three in-block positions
    # so the layout matches `hr_shape`. The leading axis uses `batch_size`
    # (it used to be hard-coded to 1, which only worked for a single sample).
    initc = tf.reshape(initial_conditions, [
        batch_size, n_block_x, nc // n_block_x, n_block_y, nc // n_block_y,
        n_block_z, nc // n_block_z
    ])
    initc = tf.transpose(initc, [0, 1, 3, 5, 2, 4, 6])
    field = mtf.import_tf_tensor(mesh, initc, shape=hr_shape)

    # Add the halo cells around every block ...
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    # ... and exchange them with the neighbouring blocks.
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # A trailing singleton dimension is added for the downsampling call and
    # dropped again right after it.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Strip the (downsampled) halo from the low-resolution blocks.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack: use a custom slicewise op instead of mtf.reshape to collapse the
    # block dimensions.
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=tf.float32,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    state = mtfpm.lpt_init(
        low,
        high,
        a0,
        kv_lr,
        kv_hr,
        halo_size,
        hr_shape,
        lr_shape,
        k_dims,
        part_shape[1:],
        downsampling_factor=downsampling_factor,
        antialias=True,
    )

    # The N-body integration is intentionally not run in this prototype; the
    # LPT state is painted directly.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack: use a custom slicewise op instead of mtf.reshape to collapse the
    # block dimensions back to the full-resolution layout.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return final_field
示例#13
0
  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               scope=None,
               mesh_shape="",
               layout=""):
    """Builds the embedding, encoder and pooler sub-graphs of the model.

    Args:
      config: model configuration. It is deep-copied, so the caller's object
        is never modified.
      is_training: bool. When false, the three dropout probabilities of the
        copied config are set to 0.
      input_ids: rank-2 mesh tensor of token ids; its second dimension is
        used as the sequence dimension.
      input_mask: optional mesh tensor over the same dimensions as
        `input_ids`. Positions where it is 0 receive an attention bias of
        -10000.
      token_type_ids: optional int32 mesh tensor with the shape of
        `input_ids`; defaults to all zeros.
      scope: optional variable scope name; defaults to "bert".
      mesh_shape: forwarded to the "moe" layers.
      layout: forwarded to the "moe" layers.

    Raises:
      ValueError: if `config.block_layers` contains an unknown layer type.
    """
    self.config = copy.deepcopy(config)
    del config
    if not is_training:
      self.config.layer_output_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0
      self.config.feedforward_intermediate_dropout_prob = 0.0
    input_shape = input_ids.shape
    assert input_shape.ndims == 2

    self._seq_dim = input_shape.dims[1]
    self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
    self._extra_losses = []
    mesh = input_ids.mesh

    if token_type_ids is None:
      token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        self.embedding_table = mtf.get_variable(
            mesh, "word_embeddings",
            mtf.Shape([self.vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        self.word_embedding_output = mtf.gather(
            self.embedding_table, input_ids, self.vocab_dim)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = self.word_embedding_output

        token_type_table = mtf.get_variable(
            mesh, "token_type_embeddings",
            mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        # `token_type_ids` is never None here: it was defaulted to zeros
        # above, so the token type embedding is always added.
        self.embedding_output += mtf.gather(
            token_type_table, token_type_ids, self.token_type_vocab_dim)
        if self.config.position_signal == "embedding":
          full_position_table = mtf.get_variable(
              mesh, "position_embeddings",
              mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
              initializer=self.embedding_initializer)
          # Keep only the first `seq_dim.size` positions and rename the
          # position dimension so it lines up with the sequence dimension.
          short_position_table = mtf.rename_dimension(
              mtf.slice(full_position_table, 0, self.seq_dim.size,
                        self.max_position_embeddings_dim.name),
              self.max_position_embeddings_dim.name, self.seq_dim.name)
          self.embedding_output += short_position_table
        self.embedding_output = self.normalize(self.embedding_output)
        self.embedding_output = mtf.dropout(
            self.embedding_output, is_training,
            keep_prob=1.0 - self.config.layer_output_dropout_prob)

      with tf.variable_scope("encoder"):
        attention_biases = []
        # Compare against None explicitly rather than relying on the
        # truthiness of a tensor object.
        if input_mask is not None:
          # [batch_dim, memory_seq_dim]
          attention_biases.append(
              (1.0 - mtf.to_float(mtf.replace_dimensions(
                  input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
        if self.config.position_signal == "relative_attention_bias":
          buckets_dim = mtf.Dimension("buckets", 32)
          rp_bucket = _relative_position_bucket(
              mtf.range(mesh, self.memory_seq_dim, tf.int32)
              - mtf.range(mesh, self.seq_dim, tf.int32),
              num_buckets=buckets_dim.size)
          bias_var = mtf.get_variable(
              mesh, "relative_attention_bias",
              [self.num_heads_dim, buckets_dim],
              initializer=tf.zeros_initializer())
          attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
        attention_bias = mtf.add_n(attention_biases)
        prev_layer_output = self.embedding_output
        self.all_encoder_layers = []
        for block_num in range(self.config.num_blocks):
          with tf.variable_scope("block_%d" % block_num):
            for layer_idx, layer_type in enumerate(self.config.block_layers):
              # Repeated layer types inside one block get a numeric suffix so
              # that their variable scopes stay distinct.
              layer_name = layer_type
              count = self.config.block_layers[:layer_idx].count(layer_type)
              if count:
                layer_name += "_%d" % count
              with tf.variable_scope(layer_name):
                x = prev_layer_output
                if self.config.residual_structure == "direct":
                  x = self.normalize(x)
                if layer_type == "attention":
                  x = self.self_attention(x, attention_bias)
                elif layer_type == "feedforward":
                  x = self.feedforward(x)
                elif layer_type == "moe":
                  x = self.moe(x, layout, mesh_shape, input_mask, is_training)
                else:
                  raise ValueError("unknown layer type " + layer_type)
                x = mtf.dropout(
                    x, is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)
                layer_output = prev_layer_output + x
                if self.config.residual_structure == "original":
                  layer_output = self.normalize(layer_output)
                prev_layer_output = layer_output
          # `prev_layer_output` is the output of the last layer of the block;
          # unlike `layer_output` it is also defined when `block_layers` is
          # empty.
          self.all_encoder_layers.append(prev_layer_output)

      self.sequence_output = prev_layer_output
      if self.config.residual_structure == "direct":
        self.sequence_output = self.normalize(self.sequence_output)

      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_dim, seq_dim, hidden_size] to a tensor of shape
      # [batch_dim, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained
        first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
        self.pooled_output = mtf.layers.dense(
            first_token_tensor,
            reduced_dims=[self.model_dim],
            new_dims=[self.model_dim],
            activation=mtf.tanh,
            kernel_initializer=self.dense_initializer,
            use_bias=self.config.use_bias)
示例#14
0
def lpt_prototype(mesh,
                  initial_conditions,
                  derivs,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """
    Prototype of function computing LPT displacement on a single grid.

    The TensorFlow initial conditions are imported onto `mesh`, an LPT state
    is computed with `mtfpm.lpt_init_single`, and the first entry of that
    state is painted onto a halo-padded blocked grid. No N-body integration
    is run.

    Args:
      mesh: mesh on which every tensor is created.
      initial_conditions: TensorFlow tensor of shape
        [batch_size, nc, nc, nc].
      derivs: unused; kept so that existing callers keep working.
      nc: number of cells per side of the grid.
      bs: unused; kept so that existing callers keep working.
      batch_size: size of the batch dimension.
      a0: unused while the N-body integration is disabled.
      a: value handed to `mtfpm.lpt_init_single`.
      nsteps: unused while the N-body integration is disabled.

    Returns:
      The painted field as a mesh tensor with dimensions
      [batch, fnx, fny, fnz].
    """
    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Dimensions of the returned field.
    ffx_dim = mtf.Dimension("fnx", nc)
    ffy_dim = mtf.Dimension("fny", nc)
    ffz_dim = mtf.Dimension("fnz", nc)

    # Dimensions of the grid used during the computation.
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Dimensions carrying the Fourier kernels.
    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    # Number of blocks along each axis and number of cells per block.
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Fourier kernels of the full grid, handed over in (y, z, x) order.
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]

    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    field = mtf.import_tf_tensor(mesh, initial_conditions, shape=part_shape)

    # NOTE(review): the state is initialised at `a`, not `a0`; this looks
    # deliberate for an LPT-only run (no stepping follows) - confirm.
    state = mtfpm.lpt_init_single(
        field,
        a,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )

    # The N-body integration is intentionally not run in this prototype; the
    # LPT state is painted directly.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack: use a custom slicewise op instead of mtf.reshape to collapse the
    # block dimensions back to the full-grid layout.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Rename the grid dimensions to the output names.
    final_field = mtf.reshape(final_field,
                              [batch_dim, ffx_dim, ffy_dim, ffz_dim])
    return final_field
示例#15
0
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None,
                             activation=mtf.relu,
                             num_microbatches=None,
                             token_embeddings=None,
                             context=None):
    """Local mixture-of-experts layer with optionally heterogeneous experts.

    The generic MoE layer is described in more detail by
    transformer_moe_layer_v1 in moe.py.

    When hparams.moe_heterogeneous_mask_info is set, a mask of shape
    [maximum hidden size, maximum # layers, # experts] is built by
    generate_heterogeneous_expert_masks. Its shape overwrites
    hparams.moe_num_layers and hparams.moe_hidden_size (hparams is mutated),
    and the slice belonging to each expert layer multiplies that layer's
    hidden activation. When it is None no mask is applied, which is the
    homogeneous case.

    The input is n-dimensional, [<batch_and_length_dims>, input_dim], and
    holds the representations of all positions in a batch of sequences. A
    learned gating function picks the experts each position is sent to and the
    weights used to combine their outputs.

    An auxiliary load-balancing loss is returned as well; it should be added
    to the training loss so that expert usage stays balanced.

    Dimensions cheat sheet:
      B: batch dim(s)
      L: original sequence length
      M: input depth
      N: output depth
      G: number of groups
      S: group size
      E: number of experts
      C: expert capacity

    Args:
      inputs: a mtf.Tensor with shape [batch_dim(s), length_dim, input_dim]
      output_dim: a mtf.Dimension (for Transformer, this is input_dim)
      hparams: model hyperparameters
      train: a boolean
      variable_dtype: a mtf.VariableDType
      layout: optional - an input to mtf.convert_to_layout_rules
      mesh_shape: optional - an input to mtf.convert_to_shape
      nonpadding: an optional Tensor with shape [batch_dim(s), length_dim]
        and the same dtype as inputs, made of ones (nonpadding) and
        zeros (padding).
      activation: a function.
      num_microbatches: number of microbatches.
      token_embeddings: a mtf.Tensor with shape
        [batch_dim(s), length_dim, input_dim]; the word embeddings that
        correspond to the inputs. They are reshaped and cast only when
        hparams.moe_word_embed_mode is not None, and are handed to the gating
        function.
      context: a Context.

    Returns:
      outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
      loss: a mtf scalar (gating loss scaled by hparams.moe_loss_coef)

    Raises:
      ValueError: on unrecognized hparams.moe_gating
    """
    orig_inputs = inputs

    experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

    mask_info = hparams.moe_heterogeneous_mask_info
    use_heterogeneous_mask = mask_info is not None
    if use_heterogeneous_mask:
        tf.logging.info("moe_heterogeneous_mask_info: {}".format(mask_info))
        heterogeneous_mask = generate_heterogeneous_expert_masks(
            mask_info,
            hparams.moe_num_experts,
            experts_dim,
            mesh=inputs.mesh,
            expert_width=hparams.moe_hidden_size)
        # The mask's maximum extents replace the configured depth and width.
        hparams.moe_num_layers = heterogeneous_mask.shape[1].size
        hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
    hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

    # We "cheat" and inspect the mesh shape and layout so that the number of
    # groups ends up a multiple of the mesh dimension those groups are split
    # over.
    all_dims = orig_inputs.shape.dims
    batch_and_length_dims = all_dims[:-1]
    input_dim = all_dims[-1]

    # Hack: we assume that
    #   "outer_batch" == replication of experts
    #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
    #
    # num_groups is then required to be a multiple of mesh_dim_size.
    if all_dims[0].name == "outer_batch":
        outer_batch_dim, orig_batch_dim = all_dims[:2]
    else:
        outer_batch_dim = mtf.Dimension("outer_batch", 1)
        orig_batch_dim = all_dims[0]

    # Number of MoE inputs per replica: every position across
    # batch_and_length_dims, divided by the outer batch size.
    positions_per_replica = 1
    for dim in batch_and_length_dims:
        positions_per_replica *= dim.size
    positions_per_replica = positions_per_replica // outer_batch_dim.size

    mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                    orig_batch_dim)
    num_groups, group_size = moe._split_into_groups(  # pylint: disable=protected-access
        positions_per_replica, hparams.moe_group_size, mesh_dim_size)
    # TODO(barretzoph): implementation without pylint calls?

    group_size_dim = mtf.Dimension("group", group_size)
    num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

    # OGSM Tensor
    moe_input_dims = [
        outer_batch_dim, num_groups_dim, group_size_dim, input_dim
    ]
    inputs = mtf.reshape(inputs, moe_input_dims)

    # Token embeddings may optionally be used by the router to decide where
    # tokens are sent.
    if hparams.moe_word_embed_mode is not None:
        reshaped_embeddings = mtf.reshape(token_embeddings, moe_input_dims)
        token_embeddings = mtf.cast(reshaped_embeddings, inputs.dtype)

    # Each sequence sends expert_capacity positions to each expert.
    capacity_factor = (hparams.moe_capacity_factor_train
                       if train else hparams.moe_capacity_factor_eval)
    scaled_capacity = int(
        (group_size_dim.size * capacity_factor) / experts_dim.size)
    expert_capacity = max(min(group_size_dim.size, scaled_capacity),
                          hparams.moe_min_expert_capacity)
    tf.logging.info("expert_capacity: %d" % expert_capacity)
    expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
    experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
    batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)

    if nonpadding is not None:
        # Adding to zeros broadcasts nonpadding to the full batch/length shape.
        broadcast_base = mtf.zeros(inputs.mesh,
                                   batch_and_length_dims,
                                   dtype=inputs.dtype)
        nonpadding = mtf.reshape(broadcast_base + nonpadding,
                                 moe_input_dims[:-1])

    # Supported gating modes mapped to the gating function in `moe`.
    gating_fn_names = {
        "top_2": "_top_2_gating",
        "top_n": "_top_n_gating",
        "switch": "_switch_gating",
        "ntlb": "_ntlb_gating",
        "switch_max": "_switch_max_gating",
        "expert_selection": "_expert_selection_gating",
    }
    if hparams.moe_gating not in gating_fn_names:
        raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

    gating_kwargs = dict(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches,
        token_embeddings=token_embeddings)
    if hparams.moe_gating == "expert_selection":
        # Only this gating variant takes the group size and an explicit name.
        gating_kwargs["group_size_dim"] = group_size_dim
        gating_kwargs["name"] = "expert_selection_gating"

    # combine_tensor,
    # dispatch_tensor  OG`SEC Tensors
    # (G is generally split along mesh dim)
    gating_fn = getattr(moe, gating_fn_names[hparams.moe_gating])
    dispatch_tensor, combine_tensor, loss = gating_fn(**gating_kwargs)

    # Leading dims of the per-expert tensors in their two layouts.
    batch_split_dims = [
        outer_batch_dim, experts_dim_unsplit, num_groups_dim,
        expert_capacity_dim
    ]
    expert_split_dims = [
        outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim
    ]

    expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                               mtf.Shape(batch_split_dims + [input_dim]))

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for
    # non-model-parallel versions, this has no effect.
    d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
    expert_inputs = mtf.reshape(
        expert_inputs, mtf.Shape(expert_split_dims + [d_model_split_dim]))

    # Split over batch -> split over experts
    expert_inputs = mtf.reshape(expert_inputs,
                                mtf.Shape(expert_split_dims + [input_dim]))

    # Pretend we have heterogeneous_mask with shape [moe_num_layers, num_experts]
    for layer_idx in range(hparams.moe_num_layers):
        with tf.variable_scope("expert_layer_{}".format(layer_idx)):
            if layer_idx > 0:
                # Residual connection around every layer after the first; the
                # layer input is normalized before entering the expert.
                res_h = expert_inputs
                expert_inputs = transformer.sublayer_rms_norm(
                    expert_inputs, None, context)
            else:
                res_h = 0.0

            # Now feed the expert inputs through the experts.
            h = mtf.layers.dense_product(
                expert_inputs,
                reduced_dims=expert_inputs.shape.dims[-1:],
                new_dims=[hidden_dim],
                expert_dims=[experts_dim],
                activation_functions=activation,
                use_bias=False,
                variable_dtype=variable_dtype,
                name="wi")

            if hparams.moe_dropout_rate != 0.0:
                h = mtf.dropout(h,
                                is_training=train,
                                keep_prob=1.0 - hparams.moe_dropout_rate)

            if use_heterogeneous_mask:
                # Take this layer's slice of the mask, then drop the expert
                # layers dimension, keeping the first and last mask dims.
                layer_mask = mtf.slice(heterogeneous_mask, layer_idx, 1,
                                       "num_expert_layers")
                layer_mask = mtf.reshape(
                    layer_mask, [layer_mask.shape[0], layer_mask.shape[-1]])
                h *= mtf.cast(layer_mask, h.dtype)

            expert_output = mtf.layers.dense(h,
                                             output_dim,
                                             expert_dims=[experts_dim],
                                             use_bias=False,
                                             reduced_dims=h.shape.dims[-1:],
                                             variable_dtype=variable_dtype,
                                             name="wo")

            # Sublayer dropout is applied on every layer except the last.
            if layer_idx < (hparams.moe_num_layers - 1):
                expert_output = transformer.sublayer_dropout(
                    expert_output, None, context)
            expert_output += res_h
            expert_inputs = expert_output

    # Extra reshape reduces communication cost for model-parallel versions.
    # For model-parallel versions, this reshape causes an mtf.slice and for
    # non-model-parallel versions, this has no effect.
    expert_output = mtf.reshape(
        expert_output, mtf.Shape(batch_split_dims + [d_model_split_dim]))

    # Split over experts -> split over batch
    expert_output = mtf.reshape(expert_output,
                                mtf.Shape(batch_split_dims + [output_dim]))

    moe_output_dims = moe_input_dims[:-1] + [output_dim]
    output = mtf.einsum([expert_output, combine_tensor],
                        mtf.Shape(moe_output_dims))
    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
    return output, loss * hparams.moe_loss_coef
示例#16
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the forward model, loss and update ops used for reconstruction.

    A trainable field variable ('linear') is evolved with LPT (and, when
    FLAGS.nbody is set, an N-body integration), painted onto a mesh, passed
    through two small networks built from the parameters returned by
    setupfnn(), and compared with `data` in log space. A power-spectrum prior
    on the field variable is added to form the total loss, and a high-pass
    filtered gradient step on the variable is returned as an update op.

    Note that the keyword defaults are read from FLAGS when the function is
    defined, not when it is called.

    Args:
      mesh: the mtf mesh on which all tensors are created.
      data: observed field; only `data.dtype`, `data.shape` and its values
        (through tf.convert_to_tensor) are used here.
      nc: number of cells per side of the grid.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first scale factor of the integration stages.
      a: last scale factor of the integration stages.
      nsteps: number of integration stages between a0 and a.
      dtype: tf.float32 or tf.float64.

    Returns:
      A 15-tuple
      (initc, model, loss, var_grads, update_op, linearop, input_field, lr,
       R0, M0, width, chisq, prior, off, istd) where
        initc: the mtf field variable being optimized.
        model: mtf tensor, product of the two network outputs.
        loss: mtf scalar, chisq + prior.
        var_grads: one-element list with the high-pass filtered gradient of
          the loss with respect to the field variable.
        update_op: assigns variable - var_grads[0] * lr to the variable.
        linearop: assigns the imported `input_field` to the variable.
        input_field: tf placeholder of shape [batch_size, nc, nc, nc].
        lr, R0, M0, width: scalar float32 tf placeholders.
        chisq, prior: mtf scalars making up the loss.
        off, istd: float32 tf placeholders with the shape of `data`.

    Raises:
      ValueError: if dtype is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype would be unbound further down.
        raise ValueError(
            "Unsupported dtype %r: expected tf.float32 or tf.float64" %
            (dtype, ))
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    R1, R2 = 3., 3 * 1.2  # the two smoothing radii passed to cwise_fingauss
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo must stay below half of the smallest block size.
    min_block_size = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block_size:
        new_size = int(0.5 * min_block_size)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Second column of the table is the power spectrum used for the prior.
    # The file is read once; only this column is needed below.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # NOTE(review): the kernels are listed in (y, z, x) order; presumably this
    # matches the axis order of the Fourier-space tensors - confirm.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    # The field being reconstructed is a variable; `linearop` loads a value
    # fed through `input_field` into it.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only: initialize directly at the last stage.
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Slicewise "reshape" from the blocked layout to [batch, nx, ny, nz].
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)

    # Fourier-space dimensions, taken from the kernels and permuted from the
    # (y, z, x) order of `kv` back to (x, y, z).
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Apply cwise_decic in Fourier space, then subtract 1 in real space.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two copies filtered with cwise_fingauss at R1 and R2, and their
    # difference.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        """Run the `pwts` network on convolved features; sigmoid output."""
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        # Standardize the features with the stored offset and scale.
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        # `width` scales the logits before the sigmoid.
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        """Run the `mwts` network on the stacked inputs; rescaled output."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        # Standardize the inputs, then undo the output scaling at the end.
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # Same dimensions as k_dims; kept as a separate list object.
    k_dims_pr = list(k_dims)
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Return |kfield| divided by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # pk is treated as sampled on a log-regular grid over [1e-5, 1e3].
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss: the model is filtered with cwise_fingauss at R1 before
    # being compared with the data.
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Annealing controls, fed at run time.
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder(
        tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    # NOTE: weighting by mtfistd, i.e. (diff + mtfoff) * mtfistd, was tried by
    # the original author and reported as misbehaving, so a fixed 0.25 scale
    # is used and mtfistd is imported but not part of the loss.
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by a Gaussian weight whose width is set by R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply kfield by one minus a Gaussian weight (widened by 1/nyq)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss with respect to the field, filtered in Fourier
    # space before being used for the update.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
    def _call_internal(self,
                       context,
                       inputs,
                       targets=None,
                       attributes=None,
                       z=None):
        """Compute logits based on inputs (all positions in parallel).

        Also updates context if applicable.

        Args:
          context: a Context
          inputs: a Tensor
          targets: an optional Tensor
          attributes: an optional Tensor; required when
            self.attribute_embedding is set, since it is then used to gather
            the attribute embeddings.
          z: an optional Tensor; when given it is projected to model_dim by a
            dense layer named "z" and added to the embedded inputs.

        Returns:
          logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim],
            or the layer-stack output itself when self.output_vocab_dim is None.

        Raises:
          ValueError: if the length dimension exceeds the positional embedding
            table while the default positions are in use.
        """
        mesh = inputs.mesh
        if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
            # Training an ensemble where all models are trained on the same examples.
            inputs = mtf.broadcast(inputs,
                                   [self.ensemble_dim] + inputs.shape.dims)
            # attributes and targets are optional: only broadcast the ones
            # that were actually passed in.
            if (attributes is not None
                    and self.ensemble_dim not in attributes.shape.dims):
                attributes = mtf.broadcast(attributes, [self.ensemble_dim] +
                                           attributes.shape.dims)
            if targets is not None:
                targets = mtf.broadcast(targets, [self.ensemble_dim] +
                                        targets.shape.dims)
        # Reuse a shared embedding table when the context provides one.
        if "embedding" in context.shared_params:
            vocab_embedding = context.shared_params["embedding"]
        else:
            vocab_embedding = VocabEmbedding(mesh,
                                             self.input_vocab_dim,
                                             self.model_dim,
                                             context.variable_dtype,
                                             name="embedding",
                                             ensemble_dim=self.ensemble_dim)
        x = vocab_embedding.ids_to_embedding(inputs)
        if self.positional_embedding:
            if "positional_embedding" in context.shared_params:
                pos_emb_var = context.shared_params["positional_embedding"]
            else:
                pos_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.max_length_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "positional_embedding",
                    ensemble_dim=self.ensemble_dim)
            if (context.length_dim is not None
                    and context.length_dim.size > self.max_length_dim.size):
                message = (
                    "Length dimenison exceeds size of positional embedding table. "
                    "length_dim.size > max_length_dim.size %s vs %s." %
                    (context.length_dim, self.max_length_dim))
                if context.position_is_default:
                    # Definitely getting overflow in this case.
                    raise ValueError(message)
                else:
                    tf.logging.warning(
                        message +
                        " This may be OK if there are several shorter sequences packed "
                        "together.  Otherwise, the later positions will get zeros."
                    )
            if context.position_is_default:
                # Default positions: take the first length_dim.size rows of
                # the table and rename the dimension to the length dimension.
                pos_emb = mtf.rename_dimension(
                    mtf.slice(pos_emb_var, 0, context.length_dim.size,
                              self.max_length_dim.name),
                    self.max_length_dim.name, context.length_dim.name)
            else:
                # Explicit positions: look each one up in the table.
                pos_emb = mtf.gather(pos_emb_var,
                                     context.position,
                                     self.max_length_dim,
                                     output_shape=x.shape)
            x += pos_emb

        if self.attribute_embedding:
            if "attribute_embedding" in context.shared_params:
                att_emb_var = context.shared_params["attribute_embedding"]
            else:
                att_emb_var = mtf.layers.embedding_weights(
                    mesh,
                    self.attribute_dim,
                    self.model_dim,
                    context.variable_dtype,
                    "attribute_embedding",
                    ensemble_dim=self.ensemble_dim)

            att_emb = mtf.gather(att_emb_var,
                                 attributes,
                                 self.attribute_dim,
                                 output_shape=x.shape)

            # Concatenate x and the attribute embedding along model_dim, then
            # project the result back to model_dim.
            x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
            x = mtf.layers.dense(x_attribute,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="comb_x_attribute")

        if z is not None:
            z = mtf.layers.dense(z,
                                 self.model_dim,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name="z")
            x += z

        x = self.layer_stack.call(context, x)
        if self.output_vocab_dim is None:
            return x
        if self.shared_embedding_and_softmax_weights:
            logits = vocab_embedding.hidden_to_logits(x)
        else:
            logits = mtf.layers.dense(x,
                                      self.output_vocab_dim,
                                      use_bias=False,
                                      variable_dtype=context.variable_dtype,
                                      reduced_dims=x.shape.dims[-1:],
                                      name="logits")
        # Record the loss on the context when both targets and a loss list
        # are available.
        if targets is not None and context.losses is not None:
            context.losses.append(
                self._compute_loss(context, logits, targets,
                                   self.output_vocab_dim))
        if self.ensemble_dim:
            logits = reduce_ensemble_logits(logits, self.ensemble_dim,
                                            self.output_vocab_dim)
        return logits
示例#18
0
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the mesh-tensorflow graph of the reconstruction objective.

    A trainable 'linear' field variable is pushed through
    `mtfpm.lpt_init_single` (followed by `mtfpm.nbody_single` when
    `FLAGS.nbody` is set), painted on the blocked mesh, passed through
    `_cwise_smooth` between `r2c3d`/`c2r3d`, and scored against `data` with a
    Poisson log-probability. A power-spectrum prior on the linear field is
    added to obtain the total loss.

    Args:
      mesh: mesh-tensorflow mesh on which all tensors are created.
      data: observed field; converted with `tf.convert_to_tensor` and imported
        with shape [batch, nx, ny, nz].
      R0: smoothing scale used in the exponential kernel of `_cwise_smooth`.
      x0: initial value of the 'linear' variable, or None to initialise it
        from a unit normal distribution.
      nc: number of mesh cells per side.
      bs: box size (units are not visible here -- presumably Mpc/h).
      batch_size: size of the batch dimension.
      a0: first stage; also the scale factor of the LPT initialisation when
        the N-body branch is used.
      a: last stage; scale factor of the LPT-only branch.
      nsteps: number of stages between `a0` and `a`.
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple `(fields, metrics, kv)` where `fields = [fieldvar, sample]`,
      `metrics = [chisq, prior, loss]` and `kv` is the list of imported
      Fourier wave-vector tensors in (y, z, x) order.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and crashed with a NameError below.
        raise ValueError("Unsupported dtype %s; expected tf.float32 or "
                         "tf.float64" % (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Clamp the halo to at most half of the smallest block; only warn when the
    # value is actually reduced.
    max_halo = min(nc // n_block_x, nc // n_block_y, nc // n_block_z) // 2
    if halo_size > max_halo:
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, max_halo))
        halo_size = max_halo

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Only the second column of the table is needed; parse the file once.
    plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the wave-vectors are always imported as float32, even
    # when dtype is tf.float64 -- confirm this is intended.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    # The linear field is the trainable variable of the reconstruction.
    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    if FLAGS.nbody:
        # LPT initial state at a0, then N-body evolution through the stages.
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT evaluated directly at the final stage.
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Collapse the blocked layout back to [batch, nx, ny, nz] slice by slice.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    # kv is stored in (y, z, x) order; put the dimensions back in (x, y, z).
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Return |kfield| divided by the square root of the interpolated pk."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    # `diff` is only used below for the shape of its last three dimensions.
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by exp(-k^2 (R0 * bs / nc)^2)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        """Negative Poisson log-probability of data at rate plambda*(1+field)."""
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)
    chisq = mtf.reduce_sum(chisq)

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        """Draw a Poisson sample at rate plambda*(1+field); data is unused."""
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
示例#19
0
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
    """Estimate the force on the particles with a two-level mesh split.

    The particles are painted on the blocked high resolution mesh; a
    downsampled copy provides the low resolution field. Both fields are
    transformed with `r2c3d`, passed through the long-range and
    gradient-laplace kernels, transformed back and read out at the particle
    positions. The high resolution contribution has its own downsampled and
    re-upsampled copy subtracted before being added to the upsampled low
    resolution contribution.

    Args:
      state: sequence `(X, P, F)` of mesh-tensorflow tensors; only the
        positions `X` are read, `P` is returned unchanged and `F` is replaced.
      lr_shape: list of Dimensions of the low resolution mesh; the size of
        its last dimension is used as the low resolution cell count.
      hr_shape: list of Dimensions of the blocked high resolution mesh
        (batch, three block dimensions, three block-size dimensions).
      kvec_lr: wave-vector tensors of the low resolution mesh, in (y, z, x)
        order.
      kvec_hr: wave-vector tensors of the high resolution mesh, in (y, z, x)
        order.
      halo_size: integer width of the padding added around each block.
      cosmology: object exposing `Om0`.
      downsampling_factor: exponent of the downsampling; lengths are divided
        by `2**downsampling_factor`.
      pm_nc_factor: must be 1 (other values are not supported yet).
      antialias: forwarded to both `mesh_utils.downsample` calls.
      **kwargs: ignored.

    Returns:
      The tuple `(X, P, F)` with `F` the newly computed force, scaled by
      `1.5 * cosmology.Om0`.
    """
    X, P, _ = state
    # TODO: support different factor
    assert pm_nc_factor == 1
    lnc = lr_shape[-1].size
    part_shape = X.shape
    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_hr = [d.shape[0] for d in kvec_hr]
    # Reorder the FFT dimensions, which were transposed: y,z,x
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

    # Paint the particles on the high resolution mesh
    field = mtf.zeros(X.mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
    field = mesh_utils.cic_paint(field, X, halo_size)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                     halo_size)

    # Split the field into low and high resolution
    # A trailing singleton dimension is added for downsample and removed
    # right after.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    # Use the same antialias setting as the large-scale removal below so that
    # both downsampling steps are consistent.
    low = mesh_utils.downsample(field, downsampling_factor,
                                antialias=antialias)
    low = mtf.reshape(low, low.shape[:-1])
    hr_field = mtf.reshape(high, high.shape[:-1])
    # Strip the (downsampled) halo from the low resolution field.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)

    # Hack using a custom slice-wise reshape to go from the blocked layout to
    # lr_shape.
    lr_field = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
    hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

    # Chain long-range kernel -> gradient kernel for both resolutions. The
    # low resolution path used to apply the gradient kernel to `lr_kfield`,
    # silently discarding the long-range kernel output.
    kfield_lr = mesh_kernels.apply_longrange_kernel(lr_kfield,
                                                    kvec_lr,
                                                    r_split=0)
    kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(kfield_lr, kvec_lr)
    kfield_hr = mesh_kernels.apply_longrange_kernel(hr_kfield,
                                                    kvec_hr,
                                                    r_split=0)
    kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

    # Reorder the force components, which come out transposed: y,z,x
    kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
    kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

    displacement = []
    for f, g in zip(kfield_lr, kfield_hr):
        # Low resolution component: back to real space, then to the blocked
        # layout by inserting three singleton block axes after the batch axis.
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        f = mtf.slicewise(
            lambda x: tf.expand_dims(
                tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
            [f],
            output_dtype=tf.float32,
            output_shape=mtf.Shape(hr_shape[0:4] + [
                mtf.Dimension('sx_block', lnc // hr_shape[1].size),
                mtf.Dimension('sy_block', lnc // hr_shape[2].size),
                mtf.Dimension('sz_block', lnc // hr_shape[3].size)
            ]),
            name='my_reshape',
            splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [
                halo_size // 2**downsampling_factor,
                halo_size // 2**downsampling_factor
            ], block_size_dim.name)
        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                                     halo_size // 2**downsampling_factor)
        f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
        f = mesh_utils.upsample(f, downsampling_factor)
        f = mtf.reshape(f, f.shape[:-1])

        # High resolution component.
        g = mesh_utils.c2r3d(g, f.shape[-3:])
        high_shape = g.shape
        # And now we remove the large scales
        g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
        _low = mesh_utils.downsample(g,
                                     downsampling_factor,
                                     antialias=antialias)
        g = g - mtf.reshape(mesh_utils.upsample(_low, downsampling_factor),
                            g.shape)
        g = mtf.reshape(g, high_shape)

        # Read out the summed force component at the particle positions.
        d = mesh_utils.cic_readout(f + g, X, halo_size)
        displacement.append(d)

    # Stack the three components along a new "ndim" dimension.
    F = mtf.stack(displacement, "ndim", axis=4)

    F = F * 1.5 * cosmology.Om0
    return X, P, F
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

    For every block width in [2, 3, 4] (plus 1 when
    `consider_chars_as_blocks` is set) the input is mean-pooled along
    `length_dim`, each pooled block is scored with a dense layer, and blocks
    and scores are repeated back to the original length. The candidate blocks
    are then mixed position-wise with the (optionally activated) scores and
    summed. With `downsample` set, the result is finally mean-pooled along the
    length.

    Args:
      x: a Tensor containing length_dim; its last dimension is taken as the
        model dimension.
      length_dim: a Dimension
      max_subword_length: integer. Currently ignored.
      downsample: integer pooling factor, or None/0 for no downsampling.
      use_offsets: boolean. If True, every shift in [0, block width) of the
        input is also pooled.
      consider_chars_as_blocks: boolean.
      use_block_pos_embedding: boolean. Requires `context`.
      share_block_kernel: boolean. Requires `context`.
      memory_embeddings: integer. Currently ignored.
      context: Context. Only read when `use_block_pos_embedding` or
        `share_block_kernel` is set.
      block_mixing_mode: Str for block mixing; only "score_attention" has an
        effect.
      activation: Str for block ranking; "softmax" and "tanh" apply the
        corresponding function to the scores.
      downsample_function: Str; only "mean" is implemented.

    Returns:
      a Tensor with the same shape as x, except that the length dimension is
      pooled by `downsample` when it is set.

    Raises:
      ValueError: if `downsample` is set and `downsample_function` is not
        "mean".
    """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF: append a "tmp" axis, tile it n times, then
        # fold it into repeat_dim with a reshape.
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        # Single dimension named after the first one, with the product size.
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # Two scores per block.
        # NOTE(review): no "sigtanh" activation is applied further down --
        # confirm whether that branch is missing.
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        # One scoring kernel shared by every block width and offset.
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    # Axis along which the candidates of every width/offset are concatenated.
    new_tmp_dim = mtf.Dimension("extra_dim", 1)

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # this is turn off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            # int() keeps the repeat count usable as a list multiplier even
            # where math.ceil returns a float.
            block_pos_emb = _repeat(
                block_pos_emb,
                int(math.ceil(length_dim.size / float(subword_len))),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            # Mean-pool every block of `subword_len` positions.
            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            # Broadcast each block (and its score) back over its positions.
            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            # Bring the expanded tensors back to exactly the input length.
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # Pad by the (positive) shortfall; the operands used to be
                # swapped, which produced a negative pad amount.
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    # After the concat, "extra_dim" has one entry per candidate.
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    # Score-weighted sum of the candidate blocks at every position.
    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name(length_dim.name)
        if output_length.size % int(downsample) != 0:
            # Pad so the length is a multiple of the pooling factor.
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
示例#21
0
def force_single(state,
                 lr_shape,
                 hr_shape,
                 kvec_lr,
                 halo_size,
                 cosmology=Planck15,
                 pm_nc_factor=1,
                 **kwargs):
    """Estimate the force on the particles using a single mesh resolution.

    The particles are painted on the blocked mesh, the halos are reduced and
    removed, and the field is reshaped to `lr_shape`. The gradient-laplace
    kernel is applied between `r2c3d` and `c2r3d`, and each resulting
    component is moved back to the blocked layout and read out at the particle
    positions.

    Args:
      state: sequence `(X, P, F)` of mesh-tensorflow tensors; only the
        positions `X` are read, `P` is returned unchanged and `F` is replaced.
      lr_shape: list of Dimensions of the un-blocked mesh.
      hr_shape: list of Dimensions of the blocked mesh (batch, three block
        dimensions, three block-size dimensions).
      kvec_lr: wave-vector tensors, in (y, z, x) order.
      halo_size: integer width of the padding added around each block.
      cosmology: object exposing `Om0`.
      pm_nc_factor: must be 1 (other values are not supported yet).
      **kwargs: ignored.

    Returns:
      The tuple `(X, P, F)` with `F` the newly computed force, scaled by
      `1.5 * cosmology.Om0`.
    """
    X, P, _ = state
    # TODO: support different factor
    assert pm_nc_factor == 1
    mesh_nc = lr_shape[-1].size
    particle_shape = X.shape
    block_dims = hr_shape[1:4]
    block_size_dims = hr_shape[-3:]
    halo_pad = [halo_size, halo_size]

    # Paint the particles on the padded, blocked mesh.
    density = mtf.zeros(X.mesh, shape=hr_shape)
    for size_dim in block_size_dims:
        density = mtf.pad(density, halo_pad, size_dim.name)
    density = mesh_utils.cic_paint(density, X, halo_size)
    for blocks_dim, padded_dim in zip(block_dims, density.shape[-3:]):
        density = mesh_ops.halo_reduce(density, blocks_dim, padded_dim,
                                       halo_size)
    # Drop the halo borders again.
    for size_dim in block_size_dims:
        density = mtf.slice(density, halo_size, size_dim.size, size_dim.name)

    def _drop_block_axes(t):
        # Keep the first entry of each of the three block axes.
        return t[:, 0, 0, 0]

    def _add_block_axes(t):
        # Insert three singleton block axes right after the batch axis.
        for _ in range(3):
            t = tf.expand_dims(t, axis=1)
        return t

    # Custom slice-wise reshape from the blocked layout to lr_shape.
    lr_field = mtf.slicewise(_drop_block_axes, [density],
                             output_dtype=tf.float32,
                             output_shape=lr_shape,
                             name='my_dumb_reshape',
                             splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # The wave-vectors arrive in (y, z, x) order; restore (x, y, z).
    ky_dim, kz_dim, kx_dim = [k.shape[0] for k in kvec_lr]
    lr_kfield = mesh_utils.r2c3d(lr_field, [kx_dim, ky_dim, kz_dim])

    gradients = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)
    # The components come out transposed as well: y,z,x
    gradients = [gradients[2], gradients[0], gradients[1]]

    blocked_shape = mtf.Shape(hr_shape[0:4] + [
        mtf.Dimension('sx_block', mesh_nc // hr_shape[1].size),
        mtf.Dimension('sy_block', mesh_nc // hr_shape[2].size),
        mtf.Dimension('sz_block', mesh_nc // hr_shape[3].size)
    ])
    blocked_splittable = (lr_shape[:-1] + hr_shape[1:4] +
                          particle_shape[1:3])

    def _readout_component(kcomp):
        # Real-space force component, moved to the blocked layout, padded
        # with halos and interpolated at the particle positions.
        comp = mesh_utils.c2r3d(kcomp, lr_shape[-3:])
        comp = mtf.slicewise(_add_block_axes, [comp],
                             output_dtype=tf.float32,
                             output_shape=blocked_shape,
                             name='my_reshape',
                             splittable_dims=blocked_splittable)
        for size_dim in block_size_dims:
            comp = mtf.pad(comp, halo_pad, size_dim.name)
        for blocks_dim, padded_dim in zip(block_dims, comp.shape[-3:]):
            comp = mesh_ops.halo_reduce(comp, blocks_dim, padded_dim,
                                        halo_size)
        return mesh_utils.cic_readout(comp, X, halo_size)

    components = [_readout_component(kcomp) for kcomp in gradients]

    # Stack the three components along a new "ndim" dimension.
    F = mtf.stack(components, "ndim", axis=4)
    F = F * 1.5 * cosmology.Om0
    return X, P, F
示例#22
0
def lpt_prototype(mesh,
                  nc=FLAGS.nc,
                  bs=FLAGS.box_size,
                  batch_size=FLAGS.batch_size,
                  a0=FLAGS.a0,
                  a=FLAGS.af,
                  nsteps=FLAGS.nsteps):
    """Prototype computing an LPT displacement and painting the result.

    A linear field is generated with `mtfpm.linear_field`, displaced with
    `mtfpm.lpt_init_single` at scale factor `a0`, and the resulting particle
    positions are painted back on the mesh.

    Args:
      mesh: mesh-tensorflow mesh on which all tensors are created.
      nc: number of mesh cells per side.
      bs: box size, forwarded to `mtfpm.linear_field`.
      batch_size: size of the batch dimension.
      a0: scale factor passed to `mtfpm.lpt_init_single`.
      a: unused here (no N-body evolution is performed).
      nsteps: unused here (no N-body evolution is performed).

    Returns:
      A tuple `(initc, final_field)` of mesh-tensorflow tensors: the linear
      field and the painted field, the latter with shape
      [batch, nx, ny, nz].
    """
    # Only the second column of the table is needed; parse the file once.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # Clamp the halo to at most half of the smallest block; only warn when the
    # value is actually reduced.
    max_halo = min(nc // n_block_x, nc // n_block_y, nc // n_block_z) // 2
    if halo_size > max_halo:
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, max_halo))
        halo_size = max_halo

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype('float32'), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    state = mtfpm.lpt_init_single(
        initc,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    # No N-body evolution is run in this prototype; the LPT state is final.
    final_state = state

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Hack using a custom slice-wise reshape to go from the blocked layout
    # back to [batch, nx, ny, nz].
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=tf.float32,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    return initc, final_field
示例#23
0
  def _call_internal(self, context, inputs, targets=None):
    """Run the model on `inputs` and return logits for every position.

    Token and positional embeddings are summed, passed through the layer
    stack and projected to the output vocabulary. When both `targets` and
    `context.losses` are available, a label-smoothed cross-entropy loss is
    appended to `context.losses`.

    Args:
      context: a Context
      inputs: a Tensor
      targets: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim],
        or the layer-stack output itself when `self.output_vocab_dim` is None.
    """
    mesh = inputs.mesh

    # Token embeddings: reuse shared weights when the context provides them.
    if "embedding" not in context.shared_params:
      token_table = mtf.layers.embedding_weights(
          mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype,
          name="embedding")
    else:
      token_table = context.shared_params["embedding"]
    hidden = mtf.gather(token_table, inputs, self.input_vocab_dim)

    # Positional embeddings, shared or freshly created.
    if "positional_embedding" not in context.shared_params:
      position_table = mtf.layers.embedding_weights(
          mesh, self.max_length_dim, self.model_dim, context.variable_dtype,
          "positional_embedding")
    else:
      position_table = context.shared_params["positional_embedding"]

    if not context.position_is_default:
      # Explicit positions: look each of them up in the table.
      position_emb = mtf.gather(
          position_table, context.position, self.max_length_dim,
          output_shape=hidden.shape)
    else:
      # Default positions: take the leading rows of the table and rename the
      # max-length dimension to the context's length dimension.
      leading_rows = mtf.slice(position_table, 0, context.length_dim.size,
                               self.max_length_dim.name)
      position_emb = mtf.rename_dimension(
          leading_rows, self.max_length_dim.name, context.length_dim.name)
    hidden += position_emb

    hidden = self.layer_stack.call(context, hidden)
    if self.output_vocab_dim is None:
      return hidden

    if not self.shared_embedding_and_softmax_weights:
      logits = mtf.layers.dense(
          hidden, self.output_vocab_dim, use_bias=False,
          variable_dtype=context.variable_dtype,
          name="logits")
    else:
      # Tie the softmax weights to the token embeddings, scaling the
      # activations by model_dim ** -0.5 first.
      logits = mtf.einsum(
          [hidden * (self.model_dim.size ** -0.5), token_table],
          reduced_dims=[self.model_dim])

    if targets is None or context.losses is None:
      return logits

    # Label-smoothed one-hot targets.
    smoothing_floor = self.label_smoothing / self.output_vocab_dim.size
    smoothing_peak = 1.0 - self.label_smoothing + smoothing_floor
    smoothed_targets = mtf.one_hot(
        targets, self.output_vocab_dim,
        dtype=context.activation_dtype,
        on_value=smoothing_peak,
        off_value=smoothing_floor)
    per_token_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, smoothed_targets, self.output_vocab_dim,
        z_loss=self.z_loss if context.train else 0.0)
    # Weight the per-token loss by whether the target is nonzero.
    nonzero_mask = mtf.layers.weights_nonzero(
        targets, dtype=context.activation_dtype)
    context.losses.append(mtf.reduce_mean(per_token_loss * nonzero_mask))
    return logits
示例#24
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the mesh-tensorflow graph of the reconstruction prototype.

    The graph holds a trainable, block-decomposed field variable named
    'linear', evolves it with `mtfpm.lpt_init` (followed by `mtfpm.nbody`
    when `FLAGS.nbody` is set), paints the result on a grid and compares it
    with `data`. The loss is the sum of the squared, Gaussian-smoothed
    residual and a power-spectrum prior on the variable. A high-pass
    filtered gradient of that loss drives a plain gradient-descent update.

    Besides its arguments the function reads `FLAGS.nx`, `FLAGS.ny`,
    `FLAGS.hsize` and `FLAGS.nbody`, and loads the linear power spectrum
    from '../flowpm/data/Planck15_a1p00.txt' (a path relative to the
    current working directory).

    Args:
      mesh: the mtf.Mesh on which every operation is created.
      data: observed field. Must expose `.dtype`; it is converted with
        `tf.convert_to_tensor` and imported with shape
        [batch, nx, ny, nz], each spatial dimension having size `nc`.
      nc: number of cells per side of the full-resolution grid.
      bs: box size, used to rescale wavenumbers and to normalise the prior.
      batch_size: size of the batch dimension.
      a0: first value of the `stages` passed to `mtfpm.nbody`.
      a: last value of the `stages`; also the value handed to
        `mtfpm.lpt_init` when `FLAGS.nbody` is false.
      nsteps: number of `stages` between `a0` and `a` (inclusive).
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple `(initc, final_field, loss, var_grads, update_op, linearop,
      input_field, lr, R0)` where
        initc: the field variable reshaped to [batch, nx, ny, nz].
        final_field: the painted field, shape [batch, nx, ny, nz].
        loss: scalar mtf tensor, chi-square plus prior.
        var_grads: one-element list with the filtered gradient of `loss`
          with respect to the field variable.
        update_op: assign op applying `variable - lr * gradient`.
        linearop: assign op copying `input_field` into the variable.
        input_field: tf placeholder for the block-decomposed initial field.
        lr: scalar tf.float32 placeholder, the step size.
        R0: scalar tf.float32 placeholder, the smoothing scale.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "recon_prototype only supports tf.float32 and tf.float64, "
            "got %r" % (dtype,))
    print(dtype, npdtype)

    # Values at which the forward model is evaluated.
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize
    block_sizes = [nc // n_block_x, nc // n_block_y, nc // n_block_z]

    # The halo may cover at most half of the smallest block. Only warn when
    # the value is actually lowered (the clamped value itself is unchanged).
    max_halo_size = min(block_sizes) // 2
    if halo_size > max_halo_size:
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, max_halo_size))
        halo_size = max_halo_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    # Real-space dimensions of the full resolution grid
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    # Fourier-space dimensions of the full resolution grid
    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    # Number of blocks along each axis, and size of each block
    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', block_sizes[0])
    sy_dim = mtf.Dimension('sy_block', block_sizes[1])
    sz_dim = mtf.Dimension('sz_block', block_sizes[2])

    batch_dim = mtf.Dimension("batch", batch_size)

    # Linear power spectrum: only the second column is needed, and the file
    # is read once.
    # NOTE(review): the interpolation in `_cwise_prior` assumes these values
    # sit on a log-regular grid spanning [1e-5, 1e3] -- confirm against the
    # first column of the file.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    def _import_kvec(kvec, dims, scale=None):
        """Import the three 1-D axes of `kvec` on `mesh`, one per dim.

        The axes are imported in (x, y, z) order and returned as
        [ky, kz, kx]. Each axis is divided by `scale` when it is given.
        """
        axes = []
        for component, dim in zip(kvec, dims):
            values = component.squeeze().astype(npdtype)
            if scale is not None:
                values = values / scale
            axes.append(mtf.import_tf_tensor(mesh, values, shape=[dim]))
        k_x, k_y, k_z = axes
        # NOTE(review): the (y, z, x) ordering presumably follows the axis
        # layout produced by the distributed FFT helpers -- confirm.
        return [k_y, k_z, k_x]

    # Compute necessary Fourier kernels
    # The full-resolution kernels are not consumed below; they are still
    # imported so that the set of operations added to `mesh` is unchanged.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    _import_kvec(kvec, [tfx_dim, tfy_dim, tfz_dim])

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)
    kv_lr = _import_kvec(kvec_lr, [tx_dim, ty_dim, tz_dim],
                         scale=2**downsampling_factor)

    # kvec for high resolution blocks (each block padded by the halo on
    # both sides)
    padded_sizes = [size + 2 * halo_size for size in block_sizes]
    padded_sx_dim = mtf.Dimension('padded_sx_block', padded_sizes[0])
    padded_sy_dim = mtf.Dimension('padded_sy_block', padded_sizes[1])
    padded_sz_dim = mtf.Dimension('padded_sz_block', padded_sizes[2])

    kvec_hr = flowpm.kernels.fftk(padded_sizes,
                                  symmetric=False,
                                  dtype=npdtype)
    kv_hr = _import_kvec(kvec_hr,
                         [padded_sx_dim, padded_sy_dim, padded_sz_dim])

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', block_sizes[0])
    prior_sy_dim = mtf.Dimension('prior_sy_block', block_sizes[1])
    prior_sz_dim = mtf.Dimension('prior_sz_block', block_sizes[2])

    kvec_pr = flowpm.kernels.fftk(list(block_sizes),
                                  symmetric=False,
                                  dtype=npdtype)
    kv_pr = _import_kvec(kvec_pr,
                         [prior_sx_dim, prior_sy_dim, prior_sz_dim])

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    ## Distributed initial conditions

    # Trainable field, fed through `linearop` from the placeholder.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(
        data.dtype,
        [batch_size, n_block_x, n_block_y, n_block_z] + block_sizes)
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    field = fieldvar
    # Custom reshape from the block layout to [batch, nx, ny, nz]: each
    # slice drops its (size-1 per slice) block axes.
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    # Pad every block with the halo, then exchange halos between blocks.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # The downsampling helper is given a trailing size-1 dimension, which is
    # stripped again right after.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    # Remove the (downsampled) halo from the low resolution field.
    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack: custom reshape from the block layout to the low resolution grid.
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    def _k_squared(kx, ky, kz):
        """Return the rescaled squared wavenumber on the broadcast 3-D grid.

        The arguments are bound positionally from `kv_pr`, which is ordered
        [ky, kz, kx]; the parameter names only label the broadcast axis.
        """
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        return (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Get prior
    # Dimensions of the prior kernels, put back in (x, y, z) order.
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated power."""
        kk = tf.sqrt(_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        """Apply a Gaussian low-pass filter whose scale is set by R0."""
        kk = _k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Smooth the residual in Fourier space before taking the chi-square.
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Apply one minus a Gaussian filter of scale R0*bs/nc + 1/nyq."""
        kk = _k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss with respect to the field, high-pass filtered in
    # Fourier space.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    # Plain gradient-descent step with a feedable step size.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0