# Example #1
def benchmark_model(mesh):
  """Builds a benchmark graph: a random 3D volume put through repeated
  FFT/inverse-FFT round trips.

  Args:
    mesh: the mtf.Mesh on which all tensors are laid out.

  Returns:
    An mtf scalar accumulating the max absolute reconstruction error
    over all round trips (plus a final comparison after the loop).
  """
  # Logical dimensions of the batched cube.
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  spatial_dims = [mtf.Dimension(name, FLAGS.cube_size)
                  for name in ("nx", "ny", "nz")]
  # Transposed dimensions used as the forward-FFT output axes.
  transposed_dims = [mtf.Dimension(name, FLAGS.cube_size)
                     for name in ("tnx", "tny", "tnz")]

  # Random initial volume; keep a real-valued copy for error measurement.
  field = mtf.random_normal(mesh, [batch_dim] + spatial_dims)
  input_field = field
  field = mtf.cast(field, tf.complex64)

  err = 0
  # Chain several forward/inverse FFT pairs inside one session graph.
  for _ in range(FLAGS.n_ffts):
    fft_field = mpm.fft3d(field, transposed_dims)
    # NOTE(review): `* 1` kept from the original — presumably it forces a
    # fresh tensor before the inverse transform; confirm before removing.
    field = mpm.ifft3d(fft_field * 1, spatial_dims)
    err += mtf.reduce_max(mtf.abs(mtf.cast(field, tf.float32) - input_field))

  # Final reconstruction error after all round trips.
  field = mtf.cast(field, tf.float32)
  err += mtf.reduce_max(mtf.abs(field - input_field))
  return err
# Example #2
def linear_field(mesh, shape, boxsize, nc, pk, kvec,
                 seed=None, dtype=tf.float32):
  """Draws a Gaussian random field whose Fourier amplitudes follow a given
  linear power spectrum, in a mesh-distributed fashion.

  Args:
    mesh: the mtf.Mesh the tensors live on.
    shape: mtf shape of the real-space field.
    boxsize: physical box size (units consistent with 1/k of `pk`).
    nc: number of cells per side.
    pk: sampled power-spectrum values interpolated over |k|.
    kvec: list of three 1-d wavevector arrays (kx, ky, kz).
    seed: unused in this implementation; kept for interface compatibility.
    dtype: unused in this implementation; kept for interface compatibility.

  Returns:
    An mtf.Tensor holding the real-space linear field.
  """
  # Element-wise kernel: scale each Fourier mode by sqrt(P(|k|) / V).
  def _scale_by_pk(kfield, pk, kx, ky, kz):
      # Broadcast the three 1-d frequency vectors onto a full 3-d |k| grid.
      kx = tf.reshape(kx, [-1, 1, 1])
      ky = tf.reshape(ky, [1, -1, 1])
      kz = tf.reshape(kz, [1, 1, -1])
      kk = tf.sqrt((kx / boxsize * nc)**2 + (ky / boxsize * nc)**2 +
                   (kz / boxsize * nc)**2)
      kk_shape = kk.shape
      kk_flat = tf.reshape(kk, [-1])
      # Interpolate P(k) on a log-regular grid of reference points.
      pk_flat = tfp.math.interp_regular_1d_grid(
          x=kk_flat, x_ref_min=1e-05, x_ref_max=1000.0,
          y_ref=pk, grid_regularizing_transform=tf.log)
      pkmesh = tf.reshape(pk_flat, kk_shape)
      # boxsize**3 normalizes by the box volume.
      return kfield * tf.cast((pkmesh / boxsize**3)**0.5, tf.complex64)

  # k-space dimensions, permuted to match the transposed FFT layout.
  k_dims = [d.shape[0] for d in kvec]
  k_dims = [k_dims[2], k_dims[0], k_dims[1]]

  # White-noise field to be colored by the power spectrum.
  white = mtf.random_normal(mesh, shape=shape,
                            mean=0, stddev=nc**1.5, dtype=tf.float32)

  # Forward transform, apply P(k) element-wise, transform back.
  cfield = mesh_utils.r2c3d(white, k_dims)
  cfield = mtf.cwise(_scale_by_pk, [cfield, pk] + kvec,
                     output_dtype=tf.complex64)
  return mesh_utils.c2r3d(cfield, white.shape[-3:])
    def hidden_to_logits(self, hidden: mtf.Tensor,
                         context: transformer.Context) -> mtf.Tensor:
        """Function called by mtf transformer to get the logits.

        Args:
          hidden: an mtf.Tensor, hidden model states of the final decoder
            layer.
          context: a transformer.Context, the context used for the call to
            the transformer.

        Returns:
          An mtf.Tensor, the logits.
        """
        # Scale the hidden states down before projecting.
        hidden = hidden * self._output_dim.size**-0.5

        # Per-component context states: project the (renamed) hidden state
        # through the context weights and squash with tanh.
        renamed_hidden = mtf.rename_dimension(
            hidden, self._output_dim.name, self._copy_output_dim.name)
        ctx_states = mtf.einsum(
            [renamed_hidden, self._context_weights],
            reduced_dims=[self._copy_output_dim])
        ctx_states = mtf.tanh(ctx_states + self._context_weights_bias)

        # Vocabulary logits of each component, with dropout applied.
        comp_logits = mtf.einsum(
            [ctx_states, self._embedding_weights],
            reduced_dims=[self._output_dim])
        comp_logits = self._dropout(comp_logits, context)

        # Gate ("prior") logits controlling how the components are mixed.
        gate_tanh = mtf.tanh(
            mtf.einsum([self._prior_weights, hidden],
                       reduced_dims=[self._output_dim]) +
            self._prior_weights_bias)
        gate_tanh = self._dropout(gate_tanh, context)
        shared_logits = mtf.einsum(
            [self._prior_gates_vector, hidden],
            reduced_dims=[self._output_dim])
        frequent_logits = (
            mtf.einsum([self._prior_vocab_vector, gate_tanh]) +
            shared_logits + self._prior_bias)
        # Rare-vocab entries only receive the shared gate logit (broadcast
        # via a ones tensor over the rare-vocab dimension).
        gate_logits = mtf.concat([
            frequent_logits,
            mtf.ones(self._mesh,
                     mtf.Shape([self._rare_vocab_dim]),
                     dtype=shared_logits.dtype) * shared_logits
        ], self._vocab_dim.name)
        # Optionally perturb the gates with Gaussian noise at train time.
        if context.train and self._noise_std_dev != 0.0:
            gate_logits = gate_logits + mtf.random_normal(
                self._mesh, gate_logits.shape, stddev=self._noise_std_dev)
        proportions = self._sigmoid_tree(gate_logits)

        # Mix component logits by the gate proportions.
        mixed_logits = mtf.einsum([comp_logits, proportions],
                                  reduced_dims=[self._gates_dim])
        return self._rearrange_sentinels(mixed_logits)