Ejemplo n.º 1
0
  def create_positional_emb_2d(self, targets):
    """Builds the learned 2d positional embedding used for image targets.

    Two embedding tables of shape [pos_dim, model_dim] are created, one for
    rows and one for columns. Each table is looked up with the index range of
    its own axis, broadcast to [rows_dim, cols_dim, model_dim], and the two
    broadcast tensors are added together.

    Args:
      targets: an object exposing a `.mesh` attribute (only the mesh is used).

    Returns:
      The sum of the broadcast row and column embeddings, with shape
      [rows_dim, cols_dim, model_dim].
    """
    mesh = targets.mesh
    table_shape = mtf.Shape([self.pos_dim, self.model_dim])
    grid_shape = mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim])

    # Create both tables first so the variable creation order stays
    # rows-then-cols.
    row_table = mtf.get_variable(
        mesh, "positional_emb_rows", table_shape,
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)
    col_table = mtf.get_variable(
        mesh, "positional_emb_cols", table_shape,
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)

    row_ids = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
    col_ids = mtf.range(mesh, self.cols_dim, dtype=tf.int32)

    row_emb = mtf.gather(row_table, row_ids, self.pos_dim)
    row_emb = mtf.broadcast(row_emb, grid_shape)

    col_emb = mtf.gather(col_table, col_ids, self.pos_dim)
    col_emb = mtf.broadcast(col_emb, grid_shape)

    return row_emb + col_emb
Ejemplo n.º 2
0
 def _embedding_and_softmax_vars(self, mesh):
   """Creates the embedding, softmax and positional-embedding variables.

   Which variables exist depends on `hparams.transformer_type`:
     * "encoder": no targets embedding is created.
     * "decoder": no inputs embedding is created.
   When `hparams.shared_embedding` is set and a targets embedding exists, the
   inputs embedding is the same variable. When
   `hparams.shared_embedding_and_softmax_weights` is set, the softmax weights
   are the (targets, else inputs) embedding scaled by model_dim.size ** -0.5
   instead of a separate variable.

   Args:
     mesh: the mesh passed to `mtf.get_variable` for every created variable.

   Returns:
     A tuple (inputs_embedding_var, targets_embedding_var, softmax_var,
     positional_embedding_var). The first element is None for a "decoder",
     the second is None for an "encoder".
   """
   hparams = self._hparams
   if hparams.transformer_type == "encoder":
     targets_embedding_var = None
   else:
     targets_embedding_var = mtf.get_variable(
         mesh, "targets_embedding",
         mtf.Shape([self.targets_vocab_dim, self.model_dim]),
         initializer=tf.random_normal_initializer(),
         master_dtype=self.master_dtype,
         slice_dtype=self.slice_dtype,
         activation_dtype=self.activation_dtype)
   # Presence is tested with `is not None` rather than truthiness: these are
   # tensor objects, and relying on their implicit bool() value is fragile.
   if hparams.transformer_type == "decoder":
     inputs_embedding_var = None
   elif hparams.shared_embedding and targets_embedding_var is not None:
     inputs_embedding_var = targets_embedding_var
   else:
     inputs_embedding_var = mtf.get_variable(
         mesh, "inputs_embedding",
         mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
         initializer=tf.random_normal_initializer(),
         master_dtype=self.master_dtype,
         slice_dtype=self.slice_dtype,
         activation_dtype=self.activation_dtype)
   if hparams.shared_embedding_and_softmax_weights:
     # Prefer the targets embedding; fall back to the inputs embedding when
     # there is none (encoder-only model).
     if targets_embedding_var is not None:
       shared_var = targets_embedding_var
     else:
       shared_var = inputs_embedding_var
     softmax_var = shared_var * (self.model_dim.size ** -0.5)
   else:
     softmax_var = mtf.get_variable(
         mesh,
         "softmax",
         mtf.Shape([self.targets_vocab_dim, self.model_dim]),
         initializer=tf.random_normal_initializer(
             stddev=self.model_dim.size**-0.5),
         master_dtype=self.master_dtype,
         slice_dtype=self.slice_dtype,
         activation_dtype=self.activation_dtype)
   positional_embedding_var = mtf.get_variable(
       mesh, "positional_embedding",
       mtf.Shape([self.max_length_dim, self.model_dim]),
       initializer=tf.random_normal_initializer(),
       activation_dtype=self.activation_dtype)
   return (inputs_embedding_var, targets_embedding_var,
           softmax_var, positional_embedding_var)
Ejemplo n.º 3
0
def _random_logits_shape(num_heads, max_length):
  """Returns the [length, heads, memory_length] shape of a learned table."""
  return mtf.Shape([mtf.Dimension("length", max_length),
                    mtf.Dimension("heads", num_heads.size),
                    mtf.Dimension("memory_length", max_length)])


def _fit_random_logits(r, q, memory_length_dim, context):
  """Cuts a learned [length, heads, memory_length] table to the live sizes."""
  r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
  if context.mode == "incremental":
    # Only the row for the position being decoded is needed.
    return mtf.gather(r, context.position, r.shape.get_dim_by_name("length"))
  length_dim = q.shape.get_dim_by_name("length")
  return mtf.slice(r, 0, length_dim.size, length_dim.name)


def _fit_dense_logits(r, q, memory_length_dim, context):
  """Cuts dense-synthesized logits to the live memory (and query) length."""
  r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
  if context.mode != "incremental":
    length_dim = q.shape.get_dim_by_name("length")
    r = mtf.slice(r, 0, length_dim.size, "length")
  return r


def _mix_logits(logits, r, synthesize_mode, context):
  """Combines dot-product logits with synthetic ones (learned alpha or sum)."""
  if "alpha" in synthesize_mode:
    alpha = mtf.get_variable(context.mesh,
                             "alpha",
                             mtf.Shape([mtf.Dimension("alpha", 1)]),
                             initializer=tf.zeros_initializer(),
                             dtype=context.variable_dtype)
    alpha = mtf.sigmoid(alpha)
    return ((1 - alpha) * logits) + (alpha * r)
  return logits + r


def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Supported values of `synthesize_mode`:
    "random": learned table "R" only.
    "factorized": product of two learned factors "R1" and "R2".
    "dense_minus": logits produced by a dense layer "pi" on relu(q).
    "random_plus" / "random_plus_alpha": dot-product logits combined with "R".
    "dense_plus" / "dense_plus_alpha": dot-product logits combined with "pi".
  The "*_alpha" variants mix with a learned sigmoid gate; the plain "*_plus"
  variants add the two logit tensors.

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor; must have a dimension named "heads" when synthesize is True.
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention or not. When False, plain
      dot-product logits are used.
    synthesize_mode: which variant of synthesizer to use (see above).
    factorized_dim: factorized dim for synthesizers
    max_length: size of the "length"/"memory_length" axes of the learned
      tables, which are sliced down to the actual sizes. Used by every mode.
    context: context since we need context mode. Required in practice:
      `context.train` is read for dropout on every path.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if synthesize is True and synthesize_mode is not recognized.
  """
  if not synthesize:
    # Without a synthesizer, fall back to ordinary dot-product logits so that
    # `logits` is always defined below.
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
  else:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r = mtf.get_variable(context.mesh, "R",
                           _random_logits_shape(num_heads, max_length),
                           initializer=None,
                           dtype=context.variable_dtype)
      logits = _fit_random_logits(r, q, memory_length_dim, context)
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      # NOTE(review): both factors use the same [tmp, heads, memory_length]
      # shape and neither carries a "length" dimension, although the einsum
      # output does. Kept as-is to preserve variable shapes -- confirm this
      # is the intended factorization.
      factor_shape = mtf.Shape([mtf.Dimension("tmp", factorized_dim),
                                mtf.Dimension("heads", num_heads.size),
                                mtf.Dimension("memory_length", max_length)])
      r1 = mtf.get_variable(context.mesh, "R1", factor_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", factor_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], _random_logits_shape(num_heads, max_length))
      logits = _fit_random_logits(r, q, memory_length_dim, context)
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
      table_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [table_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = _fit_dense_logits(logits, q, memory_length_dim, context)
    elif synthesize_mode in ("random_plus_alpha", "random_plus"):
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      r = mtf.get_variable(context.mesh, "R",
                           _random_logits_shape(num_heads, max_length),
                           initializer=None,
                           dtype=context.variable_dtype)
      r = _fit_random_logits(r, q, memory_length_dim, context)
      logits = _mix_logits(logits, r, synthesize_mode, context)
    elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      table_dim = mtf.Dimension("memory_length", max_length)
      r = mtf.layers.dense(mtf.relu(q), [table_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = _fit_dense_logits(r, q, memory_length_dim, context)
      logits = _mix_logits(logits, r, synthesize_mode, context)
    else:
      raise ValueError(
          "unrecognized synthesize_mode \"%s\"" % synthesize_mode)

  if bias is not None:
    logits += bias

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    # Pure synthesizer logits do not inherit their shape from q and k, so the
    # output shape is spelled out explicitly.
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim

  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
Ejemplo n.º 4
0
    def attention(self,
                  x,
                  n_state,
                  mask,
                  attention_type="global",
                  name="attn"):
        """Applies one self-attention block to `x`.

        Queries, keys and values are computed from `x`, attention is applied
        ("local" or "global"), the result is projected back to the shape of
        `x`, an output bias is added, and residual dropout is applied when
        training.

        Args:
          x: input tensor, unpacked as [batch, seq, n_embd].
          n_state: a dimension whose size must be divisible by self.n_heads.
          mask: optional bias added to the attention logits in "global" mode.
            May be None (no bias). Not used in "local" mode.
          attention_type: "global" or "local".
          name: variable scope name for this block.

        Returns:
          The attention output tensor, projected with the shape of `x`.

        Raises:
          NotImplementedError: if `attention_type` is not "global" or "local".
        """
        # x :: [batch, seq, n_embd]
        batch_dim, seq_dim, embd_dim = x_shape = x.shape
        assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
        with tf.variable_scope(name):
            # Compute attention inputs
            mtfparams = mtf.transformer.attention.attention_params_simple(
                x.mesh,
                io_dim=self.dimensions["embed_dim"],
                kv_dim=self.dimensions["kv_dim"],
                heads_dim=self.dimensions["heads_dim"],
                variable_dtype=self.variable_dtype)
            q = mtfparams.compute_q(x)
            k = mtfparams.compute_k(x)
            v = mtfparams.compute_v(x)

            if self.is_incremental_inference:
                # Write the freshly computed key/value into the slot of the
                # current position and keep the cached values elsewhere.
                one_hot = mtf.one_hot(self.context.position - 1,
                                      seq_dim,
                                      dtype=self.variable_dtype.master_dtype)
                inv_one_hot = 1.0 - one_hot
                old_k, old_v = self.context.get_states(2)
                k = old_k * inv_one_hot + k * one_hot
                v = old_v * inv_one_hot + v * one_hot

            if exists(self.context):
                self.context.record_new_states([k, v])

            with tf.variable_scope("attention"):
                if attention_type == "local":
                    # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                    radius = self.params.get("local_attention_radius", 256)
                    if self.is_incremental_inference:
                        q *= one_hot
                    a = mtf_transformer.attention.local_attention_1d(
                        q,
                        k,
                        v,
                        length_dim=k.shape[1],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        radius=radius,
                        length_dim_num_splits=1,
                        fully_autoregressive=True,
                        attention_kwargs={},
                    )
                    if self.is_incremental_inference:
                        a = mtf.gather(a, self.context.position - 1, seq_dim)

                elif attention_type == "global":
                    # Default to "no bias" so the attention call below is valid
                    # when no mask is supplied.
                    broadcasted_mask = None
                    if exists(mask):
                        if not self.is_incremental_inference:
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-2], mask.shape[-1]
                                ])  # TODO: not sure this is correct
                        else:
                            # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                            mask = mtf.gather(mask, self.context.position - 1,
                                              seq_dim)
                            broadcasted_mask = mtf.broadcast(
                                mask, [
                                    batch_dim, self.dimensions["heads_dim"],
                                    mask.shape[-1]
                                ])

                    # Keys/values get their own length dimension name so it
                    # does not clash with the query length dimension.
                    k = mtf.replace_dimensions(
                        k, k.shape[1], self.dimensions["memory_len_dim"])
                    v = mtf.replace_dimensions(
                        v, v.shape[1], self.dimensions["memory_len_dim"])

                    attn_dropout_rate = self.params.get(
                        "attention_dropout", 0) if self.mode == "train" else 0

                    a = mtf_transformer.attention.attention(
                        q,
                        k,
                        v,
                        memory_length_dim=self.dimensions["memory_len_dim"],
                        key_dim=self.dimensions["kv_dim"],
                        value_dim=self.dimensions["kv_dim"],
                        bias=broadcasted_mask,
                        dropout_rate=attn_dropout_rate)
                else:
                    raise NotImplementedError(
                        "Unknown attention type {}!".format(attention_type))

            with tf.variable_scope("compute_output"):
                a = mtfparams.compute_output(a, x_shape)

            with tf.variable_scope("compute_output_bias"):
                b = mtf.get_variable(
                    x.mesh,
                    "o_b", [embd_dim],
                    initializer=tf.constant_initializer(0),
                    master_dtype=self.variable_dtype.master_dtype,
                    slice_dtype=self.variable_dtype.slice_dtype,
                    activation_dtype=self.variable_dtype.activation_dtype)
                a += b
            residual_dropout = self.params.get("residual_dropout", 0)
            if self.mode == "train" and residual_dropout > 0:
                a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")
            return a
Ejemplo n.º 5
0
    def compute_bias(self, context, memory_position, x):
        """Builds the additive attention bias for this layer.

        The bias is the sum of every applicable term: visibility masks for
        the minimum / maximum relative position, for read/write priorities
        and for matching sequence ids, plus an optional relative-attention
        term. When the relative attention type is None or "bias_shared" the
        result is stored in (and served from) `context.cache`.

        Args:
          context: a transformer.Context
          memory_position: an int32 tensor containing memory_length dimension.
          x: a Tensor - the query antecedent - required for relative attention

        Returns:
          a Tensor or None

        Raises:
          ValueError: if `self.relative_attention_type` is set to an
            unrecognized value.
        """
        lowest_offset = self.min_relative_position(context)
        highest_offset = self.max_relative_position(context)
        rel_type = self.relative_attention_type

        # The result can often be shared between similar layers.
        cacheable = rel_type is None or rel_type == "bias_shared"
        if cacheable:
            cache_key = ("self_attention_mask", lowest_offset,
                         highest_offset, rel_type, self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]

        terms = []

        def add_visibility(visible):
            # Convert a boolean visibility mask into an additive bias term.
            terms.append(
                attention.visibility_mask_to_attention_bias(
                    visible, context.activation_dtype))

        offsets = memory_position - context.position
        if lowest_offset is not None:
            add_visibility(mtf.greater_equal(offsets, lowest_offset))
        if highest_offset is not None:
            add_visibility(mtf.less_equal(offsets, highest_offset))
        if context.read_priority is not None:
            add_visibility(
                mtf.greater_equal(
                    context.read_priority,
                    mtf.layers.rename_length_to_memory_length(
                        context.write_priority)))

        # Subsequence id should only be set if we are in the decoder and have
        # multiple targets per input. This will allow each sub-target to only
        # attend to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            sequence_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            sequence_id = context.sequence_id
        else:
            sequence_id = None
        if (sequence_id is not None
                and context.length_dim in sequence_id.shape):
            add_visibility(
                mtf.equal(
                    sequence_id,
                    self.rename_length_to_memory_length(sequence_id,
                                                        context)))

        if rel_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            rp_bucket = _relative_position_bucket(
                offsets,
                bidirectional=not context.model.fully_autoregressive,
                num_buckets=buckets_dim.size)
            if rel_type in ("bias", "bias_shared"):
                bias_shape = [heads_dim, buckets_dim]
                if context.model.ensemble_dim:
                    bias_shape = [context.model.ensemble_dim] + bias_shape
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          bias_shape,
                                          dtype=context.variable_dtype)
            elif rel_type == "contextual":
                expert_dims = ([context.model.ensemble_dim]
                               if context.model.ensemble_dim else None)
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual",
                                      expert_dims=expert_dims)
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" %
                    self.relative_attention_type)
            terms.append(mtf.gather(values, rp_bucket, buckets_dim))

        total = mtf.add_n(terms) if terms else None
        if cacheable:
            context.cache[cache_key] = total
        return total
Ejemplo n.º 6
0
  def __init__(self,
               mesh: mtf.Mesh,
               vocab_dim: mtf.Dimension,
               output_dim: mtf.Dimension,
               variable_dtype: mtf.VariableDType,
               name: str,
               ensemble_dim: mtf.Dimension,
               extra_ids: int = 0,
               dropout_rate: float = 0.0,
               gate_embedding_size: int = gin.REQUIRED,
               frequent_token_fraction: float = 0.1,
               noise_std_dev: float = 0.0):
    """Sets up the dimensions and variables of the vocabulary embedding.

    Most of the arguments get passed to `mtf.layers.embedding_weights`.

    Mixtape shares gates for low frequency tokens to improve efficiency. Since
    our vocabs are sorted in decreasing order of frequency with sentinels
    appended to the end, we need to do a little trick to ensure that the
    sentinels are treated as high frequency. If you want to treat the sentinels
    as low frequency tokens, then pass in zero for `extra_ids`.

    Args:
      mesh: the mesh used to layout the tensors.
      vocab_dim: the dimension corresponding to vocabulary.
      output_dim: the dimension corresponding to the model hidden states.
      variable_dtype: the datatype information for the variables used in the
        embedding tensors.
      name: a name to base variable names off of.
      ensemble_dim: the dimension used for ensembling. Absolutely no guarantees
        that this code will work with ensembling.
      extra_ids: a non-negative integer, the number of sentinels at the end of
        the vocab.
      dropout_rate: a float between 0 and 1, the rate to use for dropout.
      gate_embedding_size: a positive integer, the size to use for embedding for
        the gates. It is usually chosen to be much smaller than d_model.
      frequent_token_fraction: a float between 0 and 1, what fraction of tokens
        to consider as high frequency and not share gates for.
      noise_std_dev: a non-negative float, the standard deviation of the
        Gaussian noise to add to the pre-activation priors.
    """
    # Plain configuration values.
    self._mesh = mesh
    self._extra_ids = extra_ids
    self._dropout_rate = dropout_rate
    self._noise_std_dev = noise_std_dev

    # Dimensions. The vocabulary is split into a frequent head and a rare
    # tail; both halves reuse the vocabulary dimension's name.
    num_frequent = int(frequent_token_fraction * vocab_dim.size)
    self._vocab_dim = vocab_dim
    self._frequent_vocab_dim = mtf.Dimension(vocab_dim.name, num_frequent)
    self._rare_vocab_dim = mtf.Dimension(
        vocab_dim.name, vocab_dim.size - num_frequent)
    self._output_dim = output_dim
    self._copy_output_dim = mtf.Dimension("_{}_copy".format(output_dim.name),
                                          output_dim.size)
    self._pre_gates_dim = mtf.Dimension("gates", 3)
    self._gates_dim = mtf.Dimension("gates", 4)
    self._gate_embedding_dim = mtf.Dimension("gate_embedding",
                                             gate_embedding_size)

    ensemble_dims = [ensemble_dim] if ensemble_dim else []

    def new_embedding(name_template, table_vocab_dim, table_ensemble_dim):
      # Every embedding table maps into `output_dim`.
      return mtf.layers.embedding_weights(
          mesh=mesh,
          vocab_dim=table_vocab_dim,
          output_dim=output_dim,
          variable_dtype=variable_dtype,
          name=name_template.format(name),
          ensemble_dim=table_ensemble_dim)

    def new_variable(name_template, trailing_dims, initializer):
      # Every plain variable is prefixed with the ensemble dims (if any).
      return mtf.get_variable(
          mesh,
          name=name_template.format(name),
          shape=mtf.Shape(ensemble_dims + trailing_dims),
          dtype=variable_dtype,
          initializer=initializer)

    # Variables are created in a fixed order: embedding, context, then prior.
    self._embedding_weights = new_embedding(
        "{}_embedding_weights", vocab_dim, ensemble_dim)

    self._context_weights = new_embedding(
        "{}_context_weights", self._copy_output_dim,
        ensemble_dims + [self._gates_dim])
    self._context_weights_bias = new_variable(
        "{}_context_weights_bias",
        [self._gates_dim, output_dim],
        tf.zeros_initializer())

    self._prior_weights = new_embedding(
        "{}_prior_weights", self._gate_embedding_dim,
        ensemble_dims + [self._pre_gates_dim])
    self._prior_weights_bias = new_variable(
        "{}_prior_weights_bias",
        [self._pre_gates_dim, self._gate_embedding_dim],
        tf.zeros_initializer())
    self._prior_vocab_vector = new_variable(
        "{}_prior_vocab_vector",
        [self._frequent_vocab_dim, self._gate_embedding_dim],
        tf.random_normal_initializer())
    self._prior_gates_vector = new_variable(
        "{}_prior_gates_vector",
        [self._pre_gates_dim, output_dim],
        tf.random_normal_initializer())
    self._prior_bias = new_variable(
        "{}_prior_bias",
        [self._frequent_vocab_dim, self._pre_gates_dim],
        tf.random_normal_initializer())
Ejemplo n.º 7
0
 def add_var(self, name, *shape):
     """Creates a variable through `mtf.get_variable`.

     Args:
       name: forwarded as the first positional argument of `mtf.get_variable`.
       *shape: the dimensions of the variable, forwarded as one tuple through
         the `shape` keyword.

     Returns:
       Whatever `mtf.get_variable` returns.

     NOTE(review): if `mtf` is Mesh TensorFlow, `get_variable` expects a mesh
     as its first positional argument; here `name` is passed in that position
     -- confirm which `mtf` is in scope.
     """
     initializer = create_initializer()
     variable = mtf.get_variable(name, shape=shape, initializer=initializer)
     return variable
Ejemplo n.º 8
0
    def _layer_stack(self,
                     x,
                     layers,
                     encoder_output=None,
                     self_attention_mask=None,
                     encdec_attention_mask=None,
                     losses=None,
                     step_num=None,
                     encdec_tensors=None,
                     self_attention_k=None,
                     self_attention_v=None):
        """Encoder or decoder stack.

        Runs `x` through the layers named in `layers`. Each layer is applied
        to a normalized copy of `x` and added back onto `x` (residual
        connection); a final normalization is applied at the end. Passing
        `step_num` switches the stack to incremental mode, in which dropout
        is skipped and the per-layer self-attention keys/values are returned.

        Args:
          x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
          layers: a list of strings. "att" is a self-attention layer,
            "enc_att" an encoder-decoder attention layer; any other string is
            passed to `self._feedforward_layer` as the layer type.
          encoder_output: an optional mtf.Tensor with shape
            [<batch_dims>, encoder_length_dim, model_dim]
          self_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, memory_length_dim] containing values 0 or -inf.
          encdec_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
          losses: a list to be appended-to
          step_num: an optional mtf integer Scalar (used in incremental mode)
          encdec_tensors: an optional list of num_layers tuples, each of the
            form (q_var, o_var, k, v), (used in incremental mode)
          self_attention_k: an optional list of Tensors, one per "att" layer,
            each with shape [batch, heads, memory_length, kv_channels]
            (incremental mode)
          self_attention_v: an optional list of Tensors, one per "att" layer,
            each with shape [batch, heads, memory_length, kv_channels]
            (incremental mode)

        Returns:
          In non-incremental mode, a mtf.Tensor with shape
          [<batch_dims>, length_dim, model_dim]. In incremental mode, a tuple
          (x, new_self_attention_k, new_self_attention_v) where the two lists
          hold the updated keys/values of the "att" layers in order.

        Raises:
          ValueError: if hparams make no sense (not raised directly in this
            method; presumably by the layer helpers it calls).
        """
        hparams = self._hparams
        # Incremental (step-by-step decoding) mode is signalled by step_num.
        is_incremental = (step_num is not None)

        def layer_prepostprocess_dropout(x):
            # Dropout is skipped entirely in incremental mode.
            if is_incremental:
                return x
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

        # One scale vector per layer plus one for the final normalization,
        # stored together in a single variable and unstacked into a list.
        num_layers = len(layers)
        num_layer_norms = num_layers + 1
        layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
        layer_norm_combined_var = mtf.get_variable(
            x.mesh,
            "layer_norm_scale",
            mtf.Shape([layer_norms_dim, self.model_dim]),
            initializer=tf.ones_initializer(),
            activation_dtype=x.dtype)
        layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

        def normalize(x):
            # Each call consumes the next scale, so the call order of
            # normalize() determines which scale a layer gets.
            scale = layer_norm_vars.pop(0)
            variance = mtf.reduce_mean(mtf.square(x),
                                       reduced_dim=self.model_dim)
            return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

        if is_incremental:
            new_self_attention_k = []
            new_self_attention_v = []

        for lnum, layer_type in enumerate(layers):
            with tf.variable_scope("%s_%d" % (layer_type, lnum)):
                if layer_type == "att":
                    # Self attention layer
                    if is_incremental:
                        # Index into the cached k/v lists by the number of
                        # self-attention layers seen so far, not by lnum.
                        self_att_num = len(new_self_attention_k)
                        y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                            normalize(x),
                            prev_k=self_attention_k[self_att_num],
                            prev_v=self_attention_v[self_att_num],
                            step_num=step_num,
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="att")
                        new_self_attention_k.append(new_k)
                        new_self_attention_v.append(new_v)
                        x += y
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                None,
                                self_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="att"))
                elif layer_type == "enc_att":
                    # Encoder-Decoder attention layer
                    if is_incremental:
                        # Encoder-Decoder attention layer
                        # Unlike the self-attention cache, encdec_tensors is
                        # indexed by the overall layer number.
                        q_var, o_var, k, v = encdec_tensors[lnum]
                        x += mtf.layers.multihead_encdec_attention_incremental(
                            normalize(x),
                            q_var,
                            o_var,
                            k,
                            v,
                            encdec_attention_mask,
                            name="enc_att")
                    else:
                        x += layer_prepostprocess_dropout(
                            mtf.layers.multihead_attention(
                                normalize(x),
                                encoder_output,
                                encdec_attention_mask,
                                self.kv_dim,
                                self.heads_dim,
                                dropout=hparams.attention_dropout,
                                dropout_broadcast_dims=[self.length_dim],
                                master_dtype=self.master_dtype,
                                slice_dtype=self.slice_dtype,
                                name="enc_att"))
                else:
                    if is_incremental:
                        # insert length dimension.
                        x_shape = x.shape
                        shape_with_length = mtf.Shape(
                            x_shape.dims[:-1] + [mtf.Dimension("length", 1)] +
                            x_shape.dims[-1:])
                        x = mtf.reshape(x, shape_with_length)
                    # ffn layer
                    x += layer_prepostprocess_dropout(
                        self._feedforward_layer(normalize(x),
                                                layer_type,
                                                losses=losses))
                    if is_incremental:
                        # remove length dimension
                        x = mtf.reshape(x, x_shape)

        # Final normalization; this consumes the last remaining scale.
        x = layer_prepostprocess_dropout(normalize(x))
        # Every scale must have been used exactly once.
        assert not layer_norm_vars
        if is_incremental:
            return x, new_self_attention_k, new_self_attention_v
        else:
            return x
Ejemplo n.º 9
0
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 scope=None,
                 mesh_shape="",
                 layout=""):
        """Build the embedding, encoder and pooler graph for the model.

        On return the instance exposes `embedding_table`,
        `word_embedding_output`, `embedding_output`, `all_encoder_layers`
        (one entry per block), `sequence_output` and `pooled_output`.

        Args:
          config: model configuration object. It is deep-copied, so the
            dropout overrides applied below never reach the caller's instance.
          is_training: whether the graph is built for training. When false,
            the three dropout probabilities on the copied config are set to
            0.0.
          input_ids: rank-2 Mesh TensorFlow tensor of token ids (the rank is
            asserted). Its second dimension is used as the sequence
            dimension; presumably the first is the batch dimension -- TODO
            confirm against callers.
          input_mask: optional tensor. Positions where it is 0 receive an
            additive attention bias of -10000.0; positions where it is 1
            receive 0.
          token_type_ids: optional tensor of token-type ids. When None, an
            all-zero int32 tensor with the shape of `input_ids` is used.
          scope: optional variable-scope name; when None the default name
            "bert" is used.
          mesh_shape: passed through to `self.moe` for "moe" layers.
          layout: passed through to `self.moe` for "moe" layers.

        Raises:
          ValueError: if `config.block_layers` contains a layer type other
            than "attention", "feedforward" or "moe".
        """
        # Keep a private copy of the config; the caller's object is not
        # referenced again after this point.
        self.config = copy.deepcopy(config)
        del config
        if not is_training:
            # Disable every dropout when not training.
            self.config.layer_output_dropout_prob = 0.0
            self.config.attention_probs_dropout_prob = 0.0
            self.config.feedforward_intermediate_dropout_prob = 0.0
        input_shape = input_ids.shape
        assert input_shape.ndims == 2

        self._seq_dim = input_shape.dims[1]
        # Same size as the sequence dimension but with its own name; used
        # below when building the attention biases.
        # NOTE(review): `self.seq_dim` is presumably a property returning
        # `self._seq_dim` -- it is defined outside this method; confirm.
        self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
        self._extra_losses = []
        mesh = input_ids.mesh

        if token_type_ids is None:
            token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                self.embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([self.vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                self.word_embedding_output = mtf.gather(
                    self.embedding_table, input_ids, self.vocab_dim)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = self.word_embedding_output

                token_type_table = mtf.get_variable(
                    mesh,
                    "token_type_embeddings",
                    mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
                    initializer=self.embedding_initializer)
                # NOTE(review): this check is always true here, because a
                # None `token_type_ids` was replaced by zeros above.
                if token_type_ids is not None:
                    self.embedding_output += mtf.gather(
                        token_type_table, token_type_ids,
                        self.token_type_vocab_dim)
                if self.config.position_signal == "embedding":
                    # Learned absolute positions: create the full table, keep
                    # only its first `seq_dim.size` rows, and rename the
                    # position dimension to the sequence dimension so the
                    # result can be added to the token embeddings.
                    full_position_table = mtf.get_variable(
                        mesh,
                        "position_embeddings",
                        mtf.Shape(
                            [self.max_position_embeddings_dim,
                             self.model_dim]),
                        initializer=self.embedding_initializer)
                    short_position_table = mtf.rename_dimension(
                        mtf.slice(full_position_table, 0, self.seq_dim.size,
                                  self.max_position_embeddings_dim.name),
                        self.max_position_embeddings_dim.name,
                        self.seq_dim.name)
                    self.embedding_output += short_position_table
                self.embedding_output = self.normalize(self.embedding_output)
                self.embedding_output = mtf.dropout(
                    self.embedding_output,
                    is_training,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob)

            with tf.variable_scope("encoder"):
                # Additive attention biases; they are summed into a single
                # tensor that is handed to every attention layer.
                attention_biases = []
                # NOTE(review): this relies on the truthiness of the mask
                # object; `is not None` would state the intent explicitly --
                # confirm before changing.
                if input_mask:
                    # [batch_dim, memory_seq_dim]
                    # 0 where the mask is 1, -10000.0 where the mask is 0.
                    attention_biases.append((1.0 - mtf.to_float(
                        mtf.replace_dimensions(input_mask, self.seq_dim,
                                               self.memory_seq_dim))) *
                                            -10000.0)
                if self.config.position_signal == "relative_attention_bias":
                    # Learned per-head bias indexed by a bucketed pairwise
                    # offset (memory_seq index minus seq index), using 32
                    # buckets and a zero-initialized table.
                    buckets_dim = mtf.Dimension("buckets", 32)
                    rp_bucket = _relative_position_bucket(
                        mtf.range(mesh, self.memory_seq_dim, tf.int32) -
                        mtf.range(mesh, self.seq_dim, tf.int32),
                        num_buckets=buckets_dim.size)
                    bias_var = mtf.get_variable(
                        mesh,
                        "relative_attention_bias",
                        [self.num_heads_dim, buckets_dim],
                        initializer=tf.zeros_initializer())
                    attention_biases.append(
                        mtf.gather(bias_var, rp_bucket, buckets_dim))
                # NOTE(review): the list is empty when there is no mask and
                # no relative bias; the result then depends on how
                # `mtf.add_n` treats an empty list -- verify.
                attention_bias = mtf.add_n(attention_biases)
                prev_layer_output = self.embedding_output
                self.all_encoder_layers = []
                for block_num in range(self.config.num_blocks):
                    with tf.variable_scope("block_%d" % block_num):
                        for layer_idx, layer_type in enumerate(
                                self.config.block_layers):
                            # Scope name is the layer type; repeated types
                            # within a block get a "_<n>" suffix, where n is
                            # the number of earlier layers of the same type.
                            layer_name = layer_type
                            count = self.config.block_layers[:layer_idx].count(
                                layer_type)
                            if count:
                                layer_name += "_%d" % count
                            with tf.variable_scope(layer_name):
                                x = prev_layer_output
                                # "direct": normalize the layer input and
                                # leave the residual sum un-normalized.
                                if self.config.residual_structure == "direct":
                                    x = self.normalize(x)
                                if layer_type == "attention":
                                    x = self.self_attention(x, attention_bias)
                                elif layer_type == "feedforward":
                                    x = self.feedforward(x)
                                elif layer_type == "moe":
                                    x = self.moe(x, layout, mesh_shape,
                                                 input_mask, is_training)
                                else:
                                    raise ValueError("unknown layer type " +
                                                     layer_type)
                                x = mtf.dropout(
                                    x,
                                    is_training,
                                    keep_prob=1.0 -
                                    self.config.layer_output_dropout_prob)
                                # Residual connection around the layer.
                                layer_output = prev_layer_output + x
                                # "original": normalize after the residual
                                # sum instead of before the layer.
                                if self.config.residual_structure == "original":
                                    layer_output = self.normalize(layer_output)
                                prev_layer_output = layer_output
                    # Recorded once per block: the output of the block's last
                    # layer.
                    # NOTE(review): `layer_output` is unbound if
                    # `block_layers` is empty -- confirm that cannot happen.
                    self.all_encoder_layers.append(layer_output)

            self.sequence_output = prev_layer_output
            # With "direct" residuals the running sum was never normalized,
            # so normalize it once at the end.
            if self.config.residual_structure == "direct":
                self.sequence_output = self.normalize(self.sequence_output)

            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_dim, seq_dim, hidden_size] to a tensor of shape
            # [batch_dim, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained
                first_token_tensor = mtf.gather(self.sequence_output, 0,
                                                self.seq_dim)
                # Dense projection from model_dim back to model_dim with a
                # tanh activation.
                self.pooled_output = mtf.layers.dense(
                    first_token_tensor,
                    reduced_dims=[self.model_dim],
                    new_dims=[self.model_dim],
                    activation=mtf.tanh,
                    kernel_initializer=self.dense_initializer,
                    use_bias=self.config.use_bias)
Ejemplo n.º 10
0
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv, plus shortcut.

  The three convolutions produce `filters`, `4 * filters` and `filters`
  output channels respectively. Batch norm + ReLU is applied after the first
  two convolutions only; the third is followed directly by the residual
  addition and a final ReLU.

  Args:
    inputs: a `mtf.Tensor` of shape
        `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of output channels of the first and third
        convolutions (the second one uses four times as many).
    is_training: `bool` for whether the model is in training mode.
    strides: strides passed to the third convolution; the first two
        convolutions always use `[1, 1, 1, 1]`.
    projection_shortcut: optional callable `(inputs, kernel) -> mtf.Tensor`
        used to compute the shortcut branch. If None, the unmodified input
        is used as the shortcut.
    row_blocks_dim: a mtf.Dimension, row dimension which is spatially
        partitioned along a mesh axis. Only the 3x3 convolution uses it.
    col_blocks_dim: a mtf.Dimension, column dimension which is spatially
        partitioned along a mesh axis. Used by all three convolutions.

  Returns:
    The output `mtf.Tensor` of the block; its last dimension carries the
    name of the shortcut's last dimension.
  """
  in_channels_dim = inputs.shape.dims[-1]

  # Spatial kernel dimensions: 3x3 for the middle convolution, 1x1 elsewhere.
  three_h_dim = mtf.Dimension("filter_height", 3)
  three_w_dim = mtf.Dimension("filter_width", 3)
  unit_h_dim = mtf.Dimension("filter_height", 1)
  unit_w_dim = mtf.Dimension("filter_width", 1)

  # Shortcut branch: identity unless a projection is supplied.
  if projection_shortcut is None:
    residual = inputs
  else:
    projection_dim = mtf.Dimension("filtersp", filters)
    projection_kernel = mtf.get_variable(
        inputs.mesh, "kernel",
        mtf.Shape([unit_h_dim, unit_w_dim, in_channels_dim, projection_dim]))
    residual = projection_shortcut(inputs, projection_kernel)

  # Stage 1: 1x1 convolution down to `filters` channels.
  stage1_dim = mtf.Dimension("filters1", filters)
  stage1_kernel = mtf.get_variable(
      inputs.mesh, "kernel1",
      mtf.Shape([unit_h_dim, unit_w_dim, in_channels_dim, stage1_dim]))
  net = mtf.conv2d_with_blocks(
      inputs,
      stage1_kernel,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  net = batch_norm_relu(net, is_training)

  # Stage 2: 3x3 convolution up to `4 * filters` channels; the only stage
  # that is also given the row-block dimension.
  stage2_dim = mtf.Dimension("filters2", 4*filters)
  stage2_kernel = mtf.get_variable(
      net.mesh, "kernel2",
      mtf.Shape([three_h_dim, three_w_dim, stage1_dim, stage2_dim]))
  net = mtf.conv2d_with_blocks(
      net,
      stage2_kernel,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=row_blocks_dim,
      w_blocks_dim=col_blocks_dim)

  net = batch_norm_relu(net, is_training)

  # Stage 3: 1x1 convolution back to `filters` channels, applying the
  # caller-supplied strides.
  stage3_dim = mtf.Dimension("filters3", filters)
  stage3_kernel = mtf.get_variable(
      net.mesh, "wide_kernel",
      mtf.Shape([unit_h_dim, unit_w_dim, stage2_dim, stage3_dim]))
  net = mtf.conv2d_with_blocks(
      net,
      stage3_kernel,
      strides,
      padding="SAME",
      h_blocks_dim=None,
      w_blocks_dim=col_blocks_dim)

  # TODO(nikip): The original resnet code applies batch norm here, but in
  # this setup it caused no gradients to be passed, so it is left out.
  # Investigate further.

  # TODO(nikip): Maybe add residual with a projection?
  # Give the conv output the shortcut's channel-dimension name so the two
  # branches can be added.
  conv_channels_name = net.shape.dims[-1].name
  residual_channels_name = residual.shape.dims[-1].name
  net = mtf.rename_dimension(net, conv_channels_name, residual_channels_name)
  return mtf.relu(residual + net)
Ejemplo n.º 11
0
    def mtf_model_fn(self, features, mesh):
        """Build the image-transformer decoder graph.

        Args:
          features: dict of tf.Tensors. "targets" is always read; "inputs"
            is read only when the model has an input and
            `hparams.unconditional` is false.
          mesh: the mtf.Mesh on which all tensors and variables are created.

        Returns:
          A `(logits, loss)` pair: logits reshaped to
          `[batch, rows, orig_cols, channels, outputs_vocab]` and the mean
          softmax cross-entropy loss.
        """
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        act_dtype = self.activation_type

        # Targets are assumed to have a fixed vocabulary. Flatten each image
        # into a 1D sequence and build the right-shifted decoder input.
        int_targets = tf.to_int32(features["targets"])
        seq_len = hparams.img_len * hparams.img_len * hparams.num_channels
        flat_targets = tf.reshape(int_targets, [hparams.batch_size, seq_len])
        flat_shifted = common_layers.shift_right_2d(flat_targets)

        batch_dim = mtf.Dimension("batch", hparams.batch_size)

        def import_to_batch_by_length(tensor, name):
            """Import a [batch, length] tf.Tensor onto the mesh."""
            target_shape = mtf.Shape([batch_dim, self.length_dim])
            return mtf.import_tf_tensor(mesh, tensor, target_shape, name=name)

        def layer_prepostprocess_dropout(tensor):
            """Dropout whose noise is shaped [batch, model]."""
            keep = 1.0 - hparams.layer_prepostprocess_dropout
            return mtf.dropout(
                tensor,
                keep_prob=keep,
                noise_shape=mtf.Shape([batch_dim, self.model_dim]))

        targets = import_to_batch_by_length(flat_targets, "targets")
        shifted_targets = import_to_batch_by_length(flat_shifted,
                                                    "shifted_targets")

        # Never populated in this method; kept so auxiliary losses have a
        # place to be added to the total.
        extra_losses = []

        # Content embedding of the shifted targets.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([self.targets_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=act_dtype)
        x = mtf.gather(targets_embedding_var, shifted_targets,
                       self.targets_vocab_dim)

        # 2D positional embedding, flattened to [length, model].
        positional_emb = self.create_positional_emb_2d(targets)
        x += mtf.reshape(positional_emb, [self.length_dim, self.model_dim])

        # Conditional generation: add the embedding of the given input.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            raw_inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(raw_inputs, "inputs")

            inputs_embedding_var = mtf.layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
                activation_dtype=act_dtype)
            x += mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim)

        # Decoder stack: each layer is local self-attention followed by a
        # feed-forward network, both pre-normalized and with residual dropout.
        for layer_idx in range(hparams.num_decoder_layers):
            with tf.variable_scope("decoder_layer_%d" % layer_idx):
                att_input = mtf.layers.layer_norm(x,
                                                  self.model_dim,
                                                  name="layer_norm_att")
                att_output = mtf.layers.masked_local_attention_1d(
                    att_input,
                    None,
                    self.kv_dim,
                    self.heads_dim,
                    block_length=hparams.block_length,
                    name="self_att")
                x += layer_prepostprocess_dropout(att_output)

                ffn_input = mtf.layers.layer_norm(x,
                                                  self.model_dim,
                                                  name="layer_norm_ffn")
                ffn_output = mtf.layers.dense_relu_dense(
                    ffn_input,
                    self.feedforward_dim,
                    hparams.dropout,
                    dropout_broadcast_dims=[self.length_dim])
                x += layer_prepostprocess_dropout(ffn_output)

        x = mtf.layers.layer_norm(x, self.model_dim, name="final_layer_norm")

        # Logits and mean cross-entropy against one-hot targets.
        logits = mtf.layers.dense(x, self.outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   self.outputs_vocab_dim,
                                   dtype=act_dtype)
        per_position_loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.outputs_vocab_dim)
        loss = mtf.reduce_mean(per_position_loss)
        for extra in extra_losses:
            loss += extra

        # Restore the original image layout of the targets on the logits.
        image_logits_shape = mtf.Shape([
            batch_dim, self.rows_dim, self.orig_cols_dim,
            self.channels_dim, self.outputs_vocab_dim
        ])
        logits = mtf.reshape(logits, image_logits_shape)

        return logits, loss
def mnist_model(image, labels, mesh):
    """Build a small convolutional MNIST classifier on a mesh.

    The network is six 3x3 "SAME" convolutions with ReLU (each with
    `FLAGS.num_filters` output channels), a mean over the channel dimension,
    two ReLU dense layers of width `FLAGS.hidden_size`, and a 10-way linear
    classifier. Batch size is taken from `FLAGS.batch_size`.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
        skip building the loss.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
    """
    num_conv_layers = 6

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", 28)
    cols_dim = mtf.Dimension("cols_size", 28)

    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    x = mtf.import_tf_tensor(
        mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

    fh_dim = mtf.Dimension("fh", 3)
    fw_dim = mtf.Dimension("fw", 3)

    # Output-channel dimension of each conv layer: "filters1".."filters6".
    filter_dims = [
        mtf.Dimension("filters%d" % idx, FLAGS.num_filters)
        for idx in range(1, num_conv_layers + 1)
    ]
    # Each layer consumes the previous layer's output channels; the first
    # one consumes the single image channel.
    in_channel_dims = [one_channel_dim] + filter_dims[:-1]

    # Create all kernels ("kernel1".."kernel6") before applying any of them,
    # so variables are created in the same order as before.
    kernels = [
        mtf.get_variable(mesh, "kernel%d" % idx,
                         [fh_dim, fw_dim, in_dim, out_dim])
        for idx, (in_dim, out_dim) in enumerate(
            zip(in_channel_dims, filter_dims), start=1)
    ]

    for kernel in kernels:
        x = mtf.relu(
            mtf.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME"))
    # Average over the channels of the last conv layer.
    x = mtf.reduce_mean(x, reduced_dim=filter_dims[-1])

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
    # The first dense layer contracts the last two remaining dimensions of x
    # (rows and cols after the channel mean).
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-2:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")
    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
Ejemplo n.º 13
0
  def mtf_model_fn(self, features, mesh):
    """Build the block-partitioned ResNet classifier graph.

    Args:
      features: dict of tf.Tensors; "inputs" holds the images and "targets"
        the class labels.
      mesh: the mtf.Mesh on which all tensors and variables are created.

    Returns:
      A `(logits, loss)` pair: logits with shape
      `[batch, one_channel, classes]` and the mean softmax cross-entropy
      loss.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Named dimensions used throughout the graph.
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
    cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    channels_dim = mtf.Dimension("channels", 3)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # Split the image into row/column blocks, import it onto the mesh, then
    # move both block dimensions in front of the per-block spatial ones.
    blocked_tf_shape = [
        hparams.batch_size,
        hparams.row_blocks,
        hparams.rows_size // hparams.row_blocks,
        hparams.col_blocks,
        hparams.num_channels*hparams.cols_size // hparams.col_blocks,
        hparams.num_channels]
    blocked_inputs = tf.reshape(features["inputs"], blocked_tf_shape)
    x = mtf.import_tf_tensor(
        mesh, blocked_inputs,
        mtf.Shape([batch_dim, row_blocks_dim, rows_dim,
                   col_blocks_dim, cols_dim, channels_dim]))
    x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                          rows_dim, cols_dim, channels_dim])
    x = mtf.to_float(x)

    # Stem: 7x7 convolution followed by batch norm + ReLU.
    stem_kernel = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x,
        stem_kernel,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)
    x = batch_norm_relu(x, is_training)

    # Conv blocks: every layer stacks three residual block layers, taking
    # their width and depth from filter_sizes[0..2] / layer_sizes[0..2].
    for layer in range(hparams.num_layers):
      with tf.variable_scope("block_layer_%d" % layer):
        for stage in range(3):
          x = block_layer(
              inputs=x,
              filters=hparams.filter_sizes[stage],
              blocks=hparams.layer_sizes[stage],
              strides=[1, 1, 1, 1],
              is_training=is_training,
              name="block_layer%d" % (stage + 1),
              row_blocks_dim=None,
              col_blocks_dim=None)

    # Dense head: contract the last five dimensions of the conv output into
    # the hidden dimension.
    outputs = mtf.layers.dense(
        x, hidden_dim,
        reduced_dims=x.shape.dims[-5:],
        activation=mtf.relu, name="dense")

    # Targets are assumed to have a fixed vocabulary size.
    squeezed_labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh,
        tf.reshape(squeezed_labels, [hparams.batch_size]),
        mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    t2t_logits_shape = mtf.Shape([batch_dim, one_channel_dim, classes_dim])
    logits = mtf.reshape(logits, t2t_logits_shape)
    loss = mtf.reduce_mean(loss)
    return logits, loss
Ejemplo n.º 14
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of function computing LPT deplacement.

    Returns output tensorflow and mesh tensorflow tensors
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)

    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    ## Compute initial initial conditions distributed

    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)
    #
    field = fieldvar
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    #
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng  custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
Ejemplo n.º 15
0
def cifar_model(features, labels, mesh):
  """Builds a deep blocked-convolution CIFAR classifier in Mesh TensorFlow.

  The network is a stack of 18 7x7 `mtf.conv2d_with_blocks` layers (each
  followed by ReLU), a mean over the last filter dimension, 8 ReLU dense
  layers of width FLAGS.hidden_size and a final linear projection onto the
  10 classes.

  Args:
    features: a dict with the entries
      'image': a tf.Tensor which, after `bnorm`, is reshaped to
        [FLAGS.batch_size, 4, 8, 4, 8, 3].
      'label': a tf.Tensor that is reshaped to [FLAGS.batch_size], or None.
    labels: ignored. The labels are always taken from features['label'];
      the argument is kept so that existing callers keep working.
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape [] (mean softmax cross-entropy), or None
      when features['label'] is None.
  """
  features = copy.copy(features)
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  row_blocks_dim = mtf.Dimension("row_blocks", 4)
  col_blocks_dim = mtf.Dimension("col_blocks", 4)
  rows_dim = mtf.Dimension("rows_size", 8)
  cols_dim = mtf.Dimension("cols_size", 8)

  classes_dim = mtf.Dimension("classes", 10)
  # Despite its name this dimension has size 3 (the last axis of the image).
  one_channel_dim = mtf.Dimension("one_channel", 3)

  image = features['image']
  # NOTE: this deliberately shadows the `labels` argument (see docstring).
  labels = features['label']

  image = bnorm(image)

  # Split both spatial axes into 4 blocks of 8, then bring the two block
  # dimensions in front of the two within-block dimensions.
  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, one_channel_dim]))
  x = mtf.transpose(x, [
      batch_dim, row_blocks_dim, col_blocks_dim,
      rows_dim, cols_dim, one_channel_dim])

  # Convolutional stack. Layer i owns the variable "kernel<i>" and produces
  # the dimension "filters<i>"; the table below gives the number of filters
  # of layers 1..18 in order.
  fh_dim = mtf.Dimension("fh", 7)
  fw_dim = mtf.Dimension("fw", 7)
  conv_filter_sizes = [32] * 2 + [64] * 2 + [128] * 6 + [256] * 8

  in_dim = one_channel_dim
  for layer_num, num_filters in enumerate(conv_filter_sizes, start=1):
    out_dim = mtf.Dimension("filters%d" % layer_num, num_filters)
    kernel = mtf.get_variable(
        mesh, "kernel%d" % layer_num, [fh_dim, fw_dim, in_dim, out_dim])
    x = mtf.relu(mtf.conv2d_with_blocks(
        x, kernel, strides=[1, 1, 1, 1], padding="SAME",
        h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim))
    in_dim = out_dim

  # Average over the filters of the last convolution.
  x = mtf.reduce_mean(x, reduced_dim=in_dim)

  # Fully-connected stack "hidden1".."hidden8". Only the first layer names
  # its reduced dimensions explicitly: it contracts the four remaining
  # block / within-block dimensions. The others use the library default.
  num_hidden_layers = 8
  h = x
  for layer_num in range(1, num_hidden_layers + 1):
    layer_name = "hidden%d" % layer_num
    hidden_dim = mtf.Dimension(layer_name, FLAGS.hidden_size)
    extra_kwargs = {}
    if layer_num == 1:
      extra_kwargs["reduced_dims"] = x.shape.dims[-4:]
    h = mtf.layers.dense(
        h, hidden_dim,
        activation=mtf.relu, name=layer_name, **extra_kwargs)

  logits = mtf.layers.dense(h, classes_dim, name="logits")

  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
Ejemplo n.º 16
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Builds the Mesh TensorFlow graph of a field-reconstruction step.

    A variable `linear` (the field being reconstructed) is evolved with
    `mtfpm.lpt_init_single` (and `mtfpm.nbody_single` when FLAGS.nbody is
    set), painted onto a grid, passed through the two small networks returned
    by `setupfnn`, and compared with `data` in log space. The loss is that
    residual term plus a power-spectrum prior on the variable; its gradient is
    high-pass filtered and used in a plain gradient-descent update.

    Args:
        mesh: the mtf.Mesh on which every tensor is created.
        data: the target field. It must expose `.dtype` and `.shape`, be
            accepted by `tf.convert_to_tensor`, and match the mesh shape
            [batch_size, nc, nc, nc].
        nc: number of grid cells per side.
        bs: box size, used to scale wave numbers and the prior.
        batch_size: size of the batch dimension.
        a0: first value of the `stages` schedule; also the time handed to
            `lpt_init_single` when FLAGS.nbody is set (presumably the initial
            scale factor -- confirm against mtfpm).
        a: last value of the `stages` schedule.
        nsteps: number of entries in the `stages` schedule.
        dtype: tf.float32 or tf.float64; selects the matching complex dtype
            used for the Fourier transforms.

    Returns:
        A 15-tuple
        (initc, model, loss, var_grads, update_op, linearop, input_field,
         lr, R0, M0, width, chisq, prior, off, istd) where
        initc: the mtf variable `linear` itself.
        model: mtf tensor, product of the two network outputs.
        loss: mtf scalar, `chisq + prior`.
        var_grads: one-element list holding the high-pass filtered gradient
            of `loss` with respect to the variable.
        update_op: mtf op assigning `variable - var_grads[0] * lr`.
        linearop: mtf op assigning `input_field` to the variable.
        input_field: tf.placeholder of shape [batch_size, nc, nc, nc].
        lr: scalar tf.placeholder, the step size of `update_op`.
        R0: scalar tf.placeholder, smoothing scale used on the residual and
            in the gradient high-pass filter.
        M0: scalar tf.placeholder added to model and data inside the logs.
        width: scalar tf.placeholder multiplying the sigmoid argument of the
            first network.
        chisq: mtf scalar, sum of squares of the smoothed residual.
        prior: mtf scalar, the power-spectrum prior term.
        off: tf.placeholder with the shape of `data`, added to the residual.
        istd: tf.placeholder with the shape of `data`; imported onto the mesh
            but currently not used in the loss.

    Raises:
        ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and crashed on the unbound `npdtype`.
        raise ValueError(
            "dtype must be tf.float32 or tf.float64, got %r" % (dtype,))
    print("Dtype : ", dtype, npdtype)

    # Smoothing radii of the two Gaussian-filtered network inputs, and the
    # schedule of time steps.
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Parameters of the block decomposition.
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is clamped to strictly less than half of the smallest block.
    min_block_size = min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= 0.5 * min_block_size:
        new_size = int(0.5 * min_block_size)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Named dimensions.
    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Tabulated power spectrum used by the prior; the file is read once and
    # only its second column is needed.
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Fourier kernels of the full grid. Note the [y, z, x] ordering of `kv`.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # The "low resolution" grid has the same size here, so the same wave
    # vectors are imported again under the *_lr dimension names.
    kvec_lr = kvec
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # The variable being optimised, and the op that (re)initialises it from
    # a fed value.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    initc = fieldvar

    print("initc : ", initc)

    # Forward model: LPT at a0 followed by the N-body steps, or a single LPT
    # evaluation at the final stage.
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # Paint the field onto halo-padded blocks.
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange.
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove the borders.
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Reshape from the blocked layout to [batch, nx, ny, nz] through a
    # slicewise op that drops the three block axes of every slice.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    x = final_field

    # Parameters of the two networks (weights, biases, normalisations) and
    # the convolution kernel used to build the first network's features.
    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc = float_to_mtf(nc * 1., mesh, scalar)
    tfbs = float_to_mtf(bs, mesh, scalar)

    # Network inputs: the painted field after `cwise_decic` with 1
    # subtracted, and its `cwise_fingauss`-filtered versions at R1 and R2.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x12):
        """First network: sigmoid(width * MLP(conv3d features))."""
        unit_strides = [1, 1, 1, 1, 1]
        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, unit_strides,
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, unit_strides,
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x12, axis=-1), kernel, unit_strides,
                          'SAME')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x12):
        """Second network: point-wise MLP, rescaled by msy and mmy."""
        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x12, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Prior on the variable in Fourier space.
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Returns |kfield| / sqrt(P(k)), P(k) interpolated from `pk`."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        # `pk` is treated as sampled on a grid that is regular in log(k)
        # between 1e-5 and 1e3.
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Residual term: the model is filtered at R1 and compared with the data
    # in log space.
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)

    # Annealing / noise-model inputs.
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off = tf.placeholder(tf.float32, shape=data.shape)
    istd = tf.placeholder(tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    # NOTE: `mtfistd` is imported but unused. The original author noted that
    # scaling the residual by it misbehaved, so a fixed 0.25 is used below.
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiplies `kfield` by a Gaussian low-pass filter set by R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiplies `kfield` by one minus a Gaussian set by R0 and nyq."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # Gradient of the loss, high-pass filtered in Fourier space.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    # Plain gradient-descent step with a fed step size.
    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
Ejemplo n.º 17
0
  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: an list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to
      step_num: an optional mtf integer Scalar (used in incrmenental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v), (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    def layer_prepostprocess_dropout(x):
      if is_incremental:
        return x
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)
    def normalize(x):
      scale = layer_norm_vars.pop(0)
      variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
      return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    if is_incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
      with tf.variable_scope("%s_%d" % (layer_type, lnum)):
        if layer_type == "att":
          # Self attention layer
          if is_incremental:
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), None,
                    self_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att"))
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if is_incremental:
            # Encoder-Decoder attention layer
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                normalize(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_attention(
                    normalize(x), encoder_output,
                    encdec_attention_mask, self.kv_dim, self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="enc_att"))
        elif layer_type == "local_att":
          if is_incremental:
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normalize(x),
                prev_k=states.pop(0),
                prev_v=states.pop(0),
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.append(new_k)
            new_states.append(new_v)
            x += y
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.masked_local_attention_1d(
                    normalize(x),
                    self.kv_dim, self.heads_dim,
                    window_size=hparams.local_attention_window_size,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    length_per_split=mtf.tensor_dim_to_size_per_split(
                        hparams.layout, hparams.mesh_shape,
                        self.max_length_dim),
                    name="local_att"))
        elif layer_type == "compressed_att":
          if is_incremental:
            raise ValueError("compressed_att incremental not implemented")
          else:
            x += layer_prepostprocess_dropout(
                mtf.layers.multihead_self_attention_memory_compressed(
                    normalize(x),
                    mask_right=True,
                    compression_factor=hparams.compression_factor,
                    kv_channels=self.kv_dim,
                    heads=self.heads_dim,
                    dropout=hparams.attention_dropout,
                    dropout_broadcast_dims=[self.length_dim],
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="compressed_att"))
        else:
          if is_incremental:
            # insert length dimension.
            x_shape = x.shape
            shape_with_length = mtf.Shape(
                x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + x_shape.dims[-1:])
            x = mtf.reshape(x, shape_with_length)
          # ffn layer
          x += layer_prepostprocess_dropout(
              self._feedforward_layer(normalize(x), layer_type, losses=losses))
          if is_incremental:
            # remove length dimension
            x = mtf.reshape(x, x_shape)

    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
      return x, new_states
    else:
      return x
Ejemplo n.º 18
0
 def _var(x, init):
     """Create a shape-[] mtf variable on x's mesh, constant-initialized.

     The variable is named "activation-" followed by a random hexadecimal
     suffix drawn on every call, and takes its dtype from `x`.

     Args:
       x: object providing `.mesh` and `.dtype` for the new variable.
       init: value handed to `tf.constant_initializer`.

     Returns:
       Whatever `mtf.get_variable` returns for the new variable.
     """
     mesh = x.mesh
     suffix = random.randint(0, 2 ** 32)
     var_name = f"activation-{suffix:x}"
     empty_shape = []
     constant_init = tf.constant_initializer(init)
     return mtf.get_variable(mesh,
                             var_name,
                             empty_shape,
                             initializer=constant_init,
                             dtype=x.dtype)
Ejemplo n.º 19
0
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv, BN after each.

    The block output is `relu(conv_path + shortcut)`, where the shortcut is
    either the unmodified input or `projection_shortcut(inputs, kernel)`.

    Args:
      inputs: a `mtf.Tensor` of shape
          `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
      filters: `int` number of output filters.  In this implementation all
          three convolutions (and the projection kernel) use this same size.
      is_training: `bool` for whether the model is in training mode; forwarded
          to `batch_norm_relu`.
      strides: strides handed unchanged to the third convolution.  The first
          two convolutions always use `[1, 1, 1, 1]`.
      projection_shortcut: optional callable `(inputs, kernel) -> Tensor` used
          to build the shortcut branch.  If None, the input itself is the
          shortcut.
      row_blocks_dim: a mtf.Dimension, the row-blocks dimension; only the 3x3
          convolution receives it (the 1x1 convolutions get None).
      col_blocks_dim: a mtf.Dimension, the column-blocks dimension; passed to
          all three convolutions.

    Returns:
      The output `mtf.Tensor` of the block.
    """
    residual = inputs
    net = inputs
    in_channels_dim = inputs.shape.dims[-1]

    # Kernel spatial extents: 3x3 for the middle conv, 1x1 everywhere else.
    three_h = mtf.Dimension("filter_height", 3)
    three_w = mtf.Dimension("filter_width", 3)
    unit_h = mtf.Dimension("filter_height", 1)
    unit_w = mtf.Dimension("filter_width", 1)

    if projection_shortcut is not None:
        proj_out_dim = mtf.Dimension("filtersp", filters)
        proj_kernel_shape = mtf.Shape(
            [unit_h, unit_w, in_channels_dim, proj_out_dim])
        proj_kernel = mtf.get_variable(inputs.mesh, "kernel",
                                       proj_kernel_shape)
        residual = projection_shortcut(inputs, proj_kernel)

    # Stage 1: 1x1 convolution, stride 1.
    stage1_out_dim = mtf.Dimension("filters1", filters)
    stage1_shape = mtf.Shape([unit_h, unit_w, in_channels_dim, stage1_out_dim])
    stage1_kernel = mtf.get_variable(net.mesh, "kernel1", stage1_shape)
    net = mtf.conv2d_with_blocks(
        net,
        stage1_kernel,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)

    # TODO(nikip): Add Dropout?
    net = batch_norm_relu(net, is_training)

    # Stage 2: 3x3 convolution, stride 1, blocked along rows and columns.
    stage2_out_dim = mtf.Dimension("filters2", filters)
    stage2_shape = mtf.Shape(
        [three_h, three_w, stage1_out_dim, stage2_out_dim])
    stage2_kernel = mtf.get_variable(net.mesh, "kernel2", stage2_shape)
    net = mtf.conv2d_with_blocks(
        net,
        stage2_kernel,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=row_blocks_dim,
        w_blocks_dim=col_blocks_dim)

    net = batch_norm_relu(net, is_training)

    # Stage 3: 1x1 convolution carrying the caller-supplied strides.
    stage3_out_dim = mtf.Dimension("filters3", filters)
    stage3_shape = mtf.Shape([unit_h, unit_w, stage2_out_dim, stage3_out_dim])
    stage3_kernel = mtf.get_variable(net.mesh, "wide_kernel", stage3_shape)
    net = mtf.conv2d_with_blocks(
        net,
        stage3_kernel,
        strides,
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)

    # No ReLU here: the activation is applied after the residual sum below.
    net = batch_norm_relu(net, is_training, relu=False)

    # TODO(nikip): Maybe add residual with a projection?
    # Rename the shortcut's last dimension so it matches the conv path's last
    # dimension name before the two are added.
    aligned_residual = mtf.rename_dimension(
        residual, residual.shape.dims[-1].name, net.shape.dims[-1].name)
    return mtf.relu(net + aligned_residual)
Ejemplo n.º 20
0
  def _layer_stack(self,
                   x,
                   num_layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None):
    """Run x through `num_layers` pre-normalized residual layers.

    Each layer consists of self-attention, then (only when `encoder_output`
    is given) encoder-decoder attention, then a feed-forward sublayer.  Every
    sublayer consumes a normalized copy of x, and its output goes through
    dropout before being added back onto x.  After the last layer x is
    normalized once more and passed through dropout.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      num_layers: an integer
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to; forwarded to the feed-forward layer.
    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    """
    hparams = self._hparams
    uses_encoder = encoder_output is not None

    def residual_dropout(t):
      keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
      noise_shape = mtf.Shape(self.batch_dims + [self.model_dim])
      return mtf.dropout(t, keep_prob=keep_prob, noise_shape=noise_shape)

    # One scale vector per normalize() call: two or three per layer, plus one
    # for the final normalization after the loop.
    norms_per_layer = 3 if uses_encoder else 2
    layer_norms_dim = mtf.Dimension(
        "layer_norms", num_layers * norms_per_layer + 1)
    combined_scales = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    pending_scales = mtf.unstack(combined_scales, layer_norms_dim)

    def normalize(t):
      # Each call consumes the next unused scale, in order.
      scale = pending_scales.pop(0)
      mean_square = mtf.reduce_mean(mtf.square(t), reduced_dim=self.model_dim)
      inverse_rms = mtf.rsqrt(mean_square + hparams.norm_epsilon)
      return t * inverse_rms * scale

    for layer_index in range(num_layers):
      with tf.variable_scope("layer_%d" % layer_index):
        # Self attention sublayer
        self_attention_out = mtf.layers.multihead_attention(
            normalize(x), None,
            self_attention_mask, self.kv_dim, self.heads_dim,
            dropout=hparams.attention_dropout,
            dropout_broadcast_dims=[self.length_dim],
            name="self_attention")
        x += residual_dropout(self_attention_out)
        if uses_encoder:
          # Encoder-Decoder attention sublayer
          encdec_attention_out = mtf.layers.multihead_attention(
              normalize(x), encoder_output,
              encdec_attention_mask, self.kv_dim, self.heads_dim,
              dropout=hparams.attention_dropout,
              dropout_broadcast_dims=[self.length_dim],
              name="encdec_attention")
          x += residual_dropout(encdec_attention_out)
        # Feed-forward sublayer
        ffn_out = self._feedforward_layer(normalize(x), losses=losses)
        x += residual_dropout(ffn_out)

    x = residual_dropout(normalize(x))
    # Every allocated scale must have been consumed exactly once.
    assert not pending_scales
    return x
Ejemplo n.º 21
0
    def _decoder_layer_stack_incremental(self,
                                         x,
                                         step_num,
                                         encdec_tensors,
                                         self_attention_k,
                                         self_attention_v,
                                         encdec_attention_mask=None):
        """Decoder layer stack during inference.

        We are processing only one position at a time.

        The self-attention keys and values have already been computed for
        previous positions.  In addition to the decoder output, we need to
        produce the updated self-attention keys and values.

        If there is an encoder, then additional Tensors are supplied in
        encdec_tensors, which give us the keys and values for encoder-decoder
        attention as well as the weight matrices q_var and o_var.

        The number of layers is read from `hparams.num_decoder_layers`.  No
        dropout is applied in this incremental path.

        Args:
          x: a mtf.Tensor with shape [<batch_dims>, model_dim]
          step_num: an mtf integer Scalar
          encdec_tensors: an optional list of num_layers tuples, each of the
            form (q_var, o_var, k, v).  When None, the encoder-decoder
            attention sublayer is skipped.
          self_attention_k: a list of num_layers Tensors each with shape
            [batch, heads, memory_length, kv_channels]; indexed once per layer.
          self_attention_v: a list of num_layers Tensors each with shape
            [batch, heads, memory_length, kv_channels]; indexed once per layer.
          encdec_attention_mask: an optional mtf.Tensor with shape
            [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

        Returns:
          y: a mtf.Tensor with shape [<batch_dims>, model_dim]
          new_self_attention_k: a list of num_layers mtf.Tensors, with the same
            shapes as the elements of self_attention_k
          new_self_attention_v: a list of num_layers mtf.Tensors, with the same
            shapes as the elements of self_attention_v
        """
        hparams = self._hparams
        num_layers = hparams.num_decoder_layers
        # One scale per normalize() call: two or three per layer (depending on
        # whether encoder-decoder attention runs) plus the final one.
        num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
        layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
        layer_norm_combined_var = mtf.get_variable(
            x.mesh,
            "layer_norm_scale",
            mtf.Shape([layer_norms_dim, self.model_dim]),
            initializer=tf.ones_initializer(),
            activation_dtype=x.dtype)
        layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

        def normalize(x):
            # Consumes the next unused scale on every call.
            scale = layer_norm_vars.pop(0)
            variance = mtf.reduce_mean(mtf.square(x),
                                       reduced_dim=self.model_dim)
            return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

        new_self_attention_k = []
        new_self_attention_v = []
        # `range` (not the Python-2-only `xrange`) so the loop runs unchanged
        # on Python 3 without relying on a compatibility import.
        for layer in range(num_layers):
            with tf.variable_scope("layer_%d" % layer):
                # Self attention layer
                y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                    normalize(x),
                    prev_k=self_attention_k[layer],
                    prev_v=self_attention_v[layer],
                    step_num=step_num,
                    master_dtype=self.master_dtype,
                    slice_dtype=self.slice_dtype,
                    name="att")
                new_self_attention_k.append(new_k)
                new_self_attention_v.append(new_v)
                x += y
                if encdec_tensors is not None:
                    # Encoder-Decoder attention layer
                    q_var, o_var, k, v = encdec_tensors[layer]
                    x += mtf.layers.multihead_encdec_attention_incremental(
                        normalize(x),
                        q_var,
                        o_var,
                        k,
                        v,
                        encdec_attention_mask,
                        name="enc_att")
                # ffn layer
                x += self._feedforward_layer(normalize(x), layer)
        x = normalize(x)
        # Every allocated layer-norm scale must have been consumed.
        assert not layer_norm_vars
        return x, new_self_attention_k, new_self_attention_v
Ejemplo n.º 22
0
def recon_model(mesh,
                data,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the Mesh TensorFlow graph of the reconstruction model.

    A trainable linear field is evolved (LPT only, or LPT + N-body when
    `FLAGS.nbody` is set), painted onto a mesh, smoothed in Fourier space with
    a Gaussian kernel controlled by `R0`, and compared to `data` through a
    Poisson negative log-likelihood.  A power-spectrum prior on the linear
    field is added to form the total loss.

    Besides its arguments the function reads `FLAGS.nx`, `FLAGS.ny`,
    `FLAGS.hsize`, `FLAGS.nbody` and `FLAGS.plambda`, and loads the linear
    power spectrum from '../data/Planck15_a1p00.txt'.

    Args:
      mesh: the mtf mesh on which all tensors and variables are created.
      data: observed field; converted with `tf.convert_to_tensor` and imported
        with shape [batch, nx, ny, nz].
      R0: smoothing scale, wrapped in `tf.constant` and used in the Gaussian
        weights `exp(-k^2 * (R0 * bs / nc)^2)`.
      x0: initial value of the linear field variable.  If None the variable is
        initialized from a unit random normal.
      nc: number of cells per side of the grid.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first entry of the `stages` array.
      a: last entry of the `stages` array.
      nsteps: number of entries in `stages` (`np.linspace(a0, a, nsteps)`).
      dtype: tf.float32 or tf.float64.

    Returns:
      fields: `[fieldvar, sample]`, the linear field variable and a Poisson
        sample drawn from the smoothed final field.
      metrics: `[chisq, prior, loss]` with `loss = chisq + prior`.
      kv: the list `[ky, kz, kx]` of imported Fourier kernel tensors.

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype would be unbound and the function
        # would fail later with an opaque unbound-local error.
        raise ValueError(
            "Unsupported dtype %s; expected tf.float32 or tf.float64" %
            (dtype,))
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block extent; clamp it.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Read the power-spectrum table once; column 0 is k, column 1 is P(k).
    planck_table = np.loadtxt('../data/Planck15_a1p00.txt').T
    klin = planck_table[0]
    plin = planck_table[1]
    # NOTE(review): ipklin is not used anywhere below in this function.
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    # NOTE(review): the kernels are cast to float32 regardless of `dtype`.
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    # Begin simulation

    if x0 is None:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh,
                                    'linear',
                                    part_shape,
                                    initializer=tf.constant_initializer(x0))
    print("\nfieldvar : \n", fieldvar)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        # LPT only: initialize directly at the last stage.
        final_state = mtfpm.lpt_init_single(
            fieldvar,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    # Drop the three block dimensions by indexing them at 0 on every slice.
    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # kv is ordered [ky, kz, kx]; reorder its dimensions back to (x, y, z).
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        # |kfield| divided by the square root of the interpolated power
        # spectrum at each |k|.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #*nc**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.constant(R0)
    print("R0 in the recon_model : ", R0)

    def _cwise_smooth(kfield, kx, ky, kz):
        # Multiply the Fourier field by Gaussian weights set by R0.
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    # Element-wise function that applies a Fourier kernel
    plambda = FLAGS.plambda

    def _cwise_logprob(finalfield, data):
        # Negative Poisson log-likelihood of `data` for rate
        # plambda * (1 + finalfield).
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        logprob = galmean.log_prob(data)
        return -1 * logprob

    cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype)
    cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype)
    # `diff` is only used here, for the names of its last three dimensions.
    final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype)
    chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata],
                      output_dtype=tf.float32)  #
    chisq = mtf.reduce_sum(chisq)
    ##    #

    loss = chisq + prior

    def _cwise_sample(finalfield, data):
        # `data` is accepted to match the cwise input list but is not used.
        galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield))
        sample = galmean.sample()
        return sample

    sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata],
                       output_dtype=tf.float32)  #
    fields = [fieldvar, sample]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
Ejemplo n.º 23
0
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
    """Implements GBSWT from Charformer.

    For each candidate block width (2, 3, 4, and 1 when
    `consider_chars_as_blocks` is set) the input is mean-pooled into blocks
    along `length_dim`, every block is scored with a dense layer, and blocks
    and scores are expanded back along `length_dim`.  The per-position scores
    across candidates are then turned into weights (see `activation`) and used
    to sum the candidate blocks.  Optionally the result is downsampled.

    Args:
      x: a Tensor containing length_dim
      length_dim: a Dimension
      max_subword_length: integer.  Currently ignored.
      downsample: integer pooling factor along the length dimension, or a
        falsy value for no downsampling.
      use_offsets: boolean.  If set, each block width is also evaluated at
        every shift in `range(block width)`.
      consider_chars_as_blocks: boolean.  Adds block width 1.
      use_block_pos_embedding: boolean.  Adds sinusoid block-position
        embeddings before pooling; reads `context.mesh` and
        `context.variable_dtype`, so `context` must be given.
      share_block_kernel: boolean.  Uses one scoring kernel for all block
        widths; reads `context.variable_dtype`, so `context` must be given.
      memory_embeddings: integer.  Currently ignored.
      context: Context.
      block_mixing_mode: Str for block mixing; "score_attention" is the only
        value acted upon.
      activation: Str for block ranking; "softmax" and "tanh" transform the
        scores, any other value leaves them untransformed.
      downsample_function: Str; only "mean" is implemented.

    Returns:
      a Tensor holding the score-weighted sum of candidate blocks.  When
      `downsample` is set its length dimension is padded to a multiple of
      `downsample` and mean-pooled by that factor.

    Raises:
      ValueError: if `downsample` is set and `downsample_function` is not
        "mean".
    """
    # don't use this for now.
    del max_subword_length
    del memory_embeddings
    all_blocks = []
    all_scores = []
    tf.logging.info("GSW block layer")

    def _tile(x, n, tile_dim):
        # Simple tile function in MTF.
        return mtf.concat([x] * n, tile_dim.name)

    def _repeat(x, n, repeat_dim):
        # repeat function in MTF: append a size-1 "tmp" dimension, tile it n
        # times, then reshape it away while growing repeat_dim by a factor n.
        tmp_dim = mtf.Dimension("tmp", 1)
        expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
        x = mtf.reshape(x, expand_shape)
        x = _tile(x, n, tmp_dim)
        output_shape = []
        for dim in x.shape.dims:
            if dim.name == "tmp":
                continue
            if dim.name == repeat_dim.name:
                dim = mtf.Dimension(dim.name, dim.size * n)
            output_shape.append(dim)
        output_shape = mtf.Shape(output_shape)
        x = mtf.reshape(x, output_shape)
        return x

    def _combined_dim(dims):
        # A single dimension named after dims[0] whose size is the product of
        # all the given dimensions' sizes.
        return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

    # compute all subword blocks
    # TODO(yitay): handle offsets to get all blocks
    if activation == "sigtanh":
        # one score for sigmoid
        tmp_dim = mtf.Dimension("block_score", 2)
    else:
        tmp_dim = mtf.Dimension("block_score", 1)

    model_dim = x.shape[-1]
    subword_blocks_width = [2, 3, 4]

    if consider_chars_as_blocks:
        subword_blocks_width += [1]

    if share_block_kernel:
        block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
        block_kernel = mtf.get_variable(x.mesh,
                                        "block_kernel",
                                        block_kernel_shape,
                                        initializer=None,
                                        dtype=context.variable_dtype)
    else:
        block_kernel = None

    for subword_len in subword_blocks_width:
        if use_block_pos_embedding:
            # this is turn off by default. It is meant to support cases like
            # parameterized pooling or other features.
            block_len_dim = mtf.Dimension(length_dim.name, subword_len)
            # TODO(vqtran): Consider other positional embeddings.
            block_pos_emb = sinusoid_positional_embedding_weights(
                context.mesh, block_len_dim, x.shape[-1],
                context.variable_dtype.activation_dtype)
            block_pos_emb = _repeat(
                block_pos_emb, math.ceil(length_dim.size / float(subword_len)),
                block_len_dim)
        if use_offsets:
            offset_space = subword_len
        else:
            offset_space = 1
        for offsets in range(offset_space):
            if offsets > 0:
                xoff = mtf.shift(x, offsets, length_dim, wrap=False)
                if use_block_pos_embedding:
                    block_pos_emb = mtf.shift(block_pos_emb,
                                              offsets,
                                              block_pos_emb.shape[-2],
                                              wrap=False)
            else:
                xoff = x
            tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
            if length_dim.size % subword_len != 0:
                tf.logging.info("Not divisible by length")
                # add extra padding tokens
                pad_amt = int(subword_len) - int(length_dim.size % subword_len)
                kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
            else:
                kp = xoff

            if use_block_pos_embedding:
                kp += block_pos_emb

            # Look the pooled dimension up by length_dim's own name rather
            # than a hard-coded "length", so other dimension names work too.
            bx = mtf.pool_tensor_1d(
                kp,
                pool_dim=kp.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(subword_len))
            block_score = mtf.layers.dense(bx, [tmp_dim],
                                           use_bias=False,
                                           name="bx",
                                           reduced_dims=[model_dim],
                                           variable_dtype=None,
                                           kernel_weights=block_kernel)

            expand_bx = _repeat(bx, subword_len, length_dim)
            expand_scores = _repeat(block_score, subword_len, length_dim)
            if offsets > 0:
                # add offset.
                expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [offsets, 0],
                                        length_dim.name)
            new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
            if new_len.size < length_dim.size:
                # Pad up to the original length.  The amount must be the
                # (positive) shortfall; the padding applied before pooling
                # appears to make this branch unreachable today.
                pad_amt = length_dim.size - new_len.size
                expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
                expand_scores = mtf.pad(expand_scores, [0, pad_amt],
                                        length_dim.name)
            elif new_len.size > length_dim.size:
                # Trim whatever the padding/offset added beyond the original
                # length.
                expand_bx = mtf.slice(expand_bx, 0, length_dim.size,
                                      length_dim.name)
                expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                          length_dim.name)

            # Append a size-1 "extra_dim" so candidates can be concatenated
            # along it after the loop.
            new_tmp_dim = mtf.Dimension("extra_dim", 1)
            expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
            expand_scores_shape = mtf.Shape(expand_scores.shape.dims +
                                            [new_tmp_dim])
            expand_bx = mtf.reshape(expand_bx, expand_shape)
            expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
            all_blocks.append(expand_bx)
            all_scores.append(expand_scores)

    all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
    all_scores = mtf.concat(all_scores, new_tmp_dim.name)
    tf.logging.info(all_blocks)
    new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
    # Fold the score dimension into the candidate dimension.
    combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
    block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
    block_net = mtf.reshape(all_scores, block_net_shape)

    if block_mixing_mode == "score_attention":
        tf.logging.info("Using score attention")
        att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
        tf.logging.info(block_net)
        att = mtf.softmax(att, reduced_dim=att.shape[-1])
        block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
        tf.logging.info(block_net)

    if activation == "softmax":
        block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
    elif activation == "tanh":
        tf.logging.info("Using tanh")
        block_net = mtf.tanh(block_net)

    # Weight each candidate block by its score and sum over candidates.
    all_blocks = block_net * all_blocks
    all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
    output = all_blocks

    if downsample:
        output_length = output.shape.get_dim_by_name(length_dim.name)
        if output_length.size % int(downsample) != 0:
            # Pad so the length divides evenly into pooling windows.
            pad_amt = int(downsample) - int(
                output_length.size % int(downsample))
            output = mtf.pad(output, [0, pad_amt], output_length.name)
        if downsample_function == "mean":
            output = mtf.pool_tensor_1d(
                output,
                pool_dim=output.shape.get_dim_by_name(length_dim.name),
                reduce_fn=mtf.reduce_mean,
                pool_size=int(downsample))
        else:
            raise ValueError("Downsampling function not implemeneted.")

    return output
Ejemplo n.º 24
0
    def __init__(self,
                 mesh,
                 query_input_dim,
                 memory_input_dim,
                 output_dim,
                 key_dim,
                 value_dim,
                 query_heads_dims,
                 memory_heads_dims,
                 variable_dtype,
                 shared_kv=False,
                 combine_dims=True,
                 ensemble_dim=None):
        """Create the projection variables used by attention.

        Variables named "q" and "o" are always created, together with either
        one shared "kv" variable (``shared_kv=True``) or separate "k" and "v"
        variables.

        ``combine_dims`` is a hack for faster execution: the heads and
        key/value dimensions are merged into one dimension in the variables
        and the computation.  It would be unnecessary if XLA optimized einsum
        properly.

        Args:
          mesh: a Mesh
          query_input_dim: a Dimension
          memory_input_dim: a Dimension
          output_dim: a Dimension
          key_dim: a Dimension
          value_dim: a Dimension
          query_heads_dims: a list of Dimension (None is stored as [])
          memory_heads_dims: a list of Dimension (None is stored as [])
          variable_dtype: a mtf.VariableDType
          shared_kv: a boolean
          combine_dims: a boolean
          ensemble_dim: an optional Dimension, prepended to every variable
            shape when given

        Raises:
          ValueError: if shared_kv is set while key_dim != value_dim.
        """
        if shared_kv and key_dim != value_dim:
            raise ValueError("shared_kv requires key_dim == value_dim")

        # These attributes must be set before q_dims/k_dims/v_dims/o_dims are
        # read below.
        self.query_input_dim = query_input_dim
        self.memory_input_dim = memory_input_dim
        self.output_dim = output_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_heads_dims = query_heads_dims or []
        self.memory_heads_dims = memory_heads_dims or []
        self.shared_kv = shared_kv
        self.combine_dims = combine_dims

        def feature_dims(dims):
            # A single merged dimension when combining, else the dims as given.
            return [_combined_dim(dims)] if combine_dims else dims

        q_core = [query_input_dim] + feature_dims(self.q_dims)
        k_core = [memory_input_dim] + feature_dims(self.k_dims)
        v_core = [memory_input_dim] + feature_dims(self.v_dims)
        o_core = feature_dims(self.o_dims) + [output_dim]

        q_init = tf.random_normal_initializer(
            stddev=(query_input_dim.size * key_dim.size)**-0.5)
        kv_init = tf.random_normal_initializer(
            stddev=memory_input_dim.size**-0.5)
        o_init = tf.random_normal_initializer(
            stddev=mtf.Shape(self.query_heads_dims + [value_dim]).size**-0.5)

        # The optional ensemble dimension leads every variable shape.
        leading = [ensemble_dim] if ensemble_dim else []
        q_shape = leading + q_core
        k_shape = leading + k_core
        v_shape = leading + v_core
        o_shape = leading + o_core

        def new_variable(var_name, shape, init):
            return mtf.get_variable(mesh,
                                    var_name,
                                    shape,
                                    initializer=init,
                                    dtype=variable_dtype)

        self.wq = new_variable("q", q_shape, q_init)
        if shared_kv:
            self.wkv = new_variable("kv", k_shape, kv_init)
        else:
            self.wk = new_variable("k", k_shape, kv_init)
            self.wv = new_variable("v", v_shape, kv_init)
        self.wo = new_variable("o", o_shape, o_init)
Ejemplo n.º 25
0
    def mtf_model_fn(self, features, mesh):
        """Build the image-transformer decoder graph and its loss.

        Args:
          features: mapping with a "targets" entry and, for conditional
            models, an "inputs" entry.  A shallow copy is taken first.
          mesh: the mesh on which tensors and variables are created.

        Returns:
          A ``(logits, loss)`` pair.  ``logits`` is reshaped to
          [batch, rows, orig_cols, channels, outputs_vocab]; ``loss`` is the
          mean softmax cross-entropy against the unshifted targets.

        Raises:
          ValueError: if ``hparams.attention_type`` is not one of
            "local1d_spatial", "local2d_spatial" or "local1d".
        """
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        act_dtype = self.activation_type

        # Flatten each image into one sequence and build the right-shifted
        # copy that is fed to the decoder.  Target vocab size is assumed fixed.
        seq_len = hparams.img_len * hparams.img_len * hparams.num_channels
        flat_targets = tf.reshape(tf.to_int32(features["targets"]),
                                  [hparams.batch_size, seq_len])
        flat_shifted = common_layers.shift_right_2d(flat_targets)

        batch_dim = mtf.Dimension("batch", hparams.batch_size)

        def to_batch_by_length(tensor, name):
            shape = mtf.Shape([batch_dim, self.length_dim])
            return mtf.import_tf_tensor(mesh, tensor, shape, name=name)

        targets = to_batch_by_length(flat_targets, "targets")
        shifted_targets = to_batch_by_length(flat_shifted, "shifted_targets")

        # Content embedding: look up the shifted targets in a learned table.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([self.targets_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=act_dtype)
        x = mtf.gather(targets_embedding_var, shifted_targets,
                       self.targets_vocab_dim)

        # Position embedding, flattened to [length, model].
        positional = mtf.reshape(self.create_positional_emb_2d(targets),
                                 [self.length_dim, self.model_dim])
        x += positional

        # Conditional case: add the embedded inputs to the target embedding.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            raw_inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = to_batch_by_length(raw_inputs, "inputs")
            inputs_embedding_var = mtf.layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
                activation_dtype=act_dtype)
            x += mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim)

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        attention_type = hparams.attention_type
        if attention_type == "local1d_spatial":
            decoder_fn = local_attention1d_spatial_decoder
        elif attention_type == "local2d_spatial":
            decoder_fn = local_attention2d_spatial_decoder
        elif attention_type == "local1d":
            decoder_fn = local_attention1d_masked_decoder
        else:
            raise ValueError("Invalid attention type.")
        decoder_output = decoder_fn(x, self.kv_dim, self.heads_dim,
                                    self.feedforward_dim, hparams)

        # Project to the output vocabulary and lay logits out as
        # [batch, length, vocab] for the loss.
        logits = mtf.layers.dense(decoder_output,
                                  self.outputs_vocab_dim,
                                  name="logits")
        flat_logits_shape = mtf.Shape(
            [batch_dim, self.length_dim, self.outputs_vocab_dim])
        logits = mtf.reshape(logits, flat_logits_shape)

        soft_targets = mtf.one_hot(targets,
                                   self.outputs_vocab_dim,
                                   dtype=act_dtype)
        loss = mtf.reduce_mean(
            mtf.layers.softmax_cross_entropy_with_logits(
                logits, soft_targets, self.outputs_vocab_dim))

        # Hook for auxiliary losses; nothing is appended to this list here.
        extra_losses = []
        for extra in extra_losses:
            loss += extra

        # Restore the original image layout of the targets.
        image_logits_shape = mtf.Shape([
            batch_dim, self.rows_dim, self.orig_cols_dim, self.channels_dim,
            self.outputs_vocab_dim
        ])
        logits = mtf.reshape(logits, image_logits_shape)

        return logits, loss
Ejemplo n.º 26
0
def rezero(x, scope, dtype):
    """Scale ``x`` by a scalar variable "g" that is initialised to zero.

    Args:
      x: the tensor to scale; its mesh hosts the variable.
      scope: variable scope under which "g" is created.
      dtype: dtype passed to ``mtf.get_variable``.

    Returns:
      ``x * g``.
    """
    with tf.variable_scope(scope):
        gate = mtf.get_variable(
            x.mesh,
            "g",
            [],  # scalar: no dimensions
            initializer=tf.constant_initializer(0),
            dtype=dtype)
        return x * gate
Ejemplo n.º 27
0
def mnist_model(image, labels, mesh):
    """Build a small blocked-convolution classifier for MNIST.

    The image is split into a 4x4 grid of 7x7 blocks, passed through two
    blocked 9x9 convolutions with ReLU, averaged over the filter dimension
    and then through two hidden dense layers and a 10-way output layer.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
        skip building the loss
      mesh: a mtf.Mesh

    Returns:
      logits: output of the final ``mtf.layers.dense`` layer over the
        "classes" dimension
      loss: mean softmax cross-entropy, or None when ``labels`` is None
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    row_blocks_dim = mtf.Dimension("row_blocks", 4)
    col_blocks_dim = mtf.Dimension("col_blocks", 4)
    rows_dim = mtf.Dimension("rows_size", 7)
    cols_dim = mtf.Dimension("cols_size", 7)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    # Import as [batch, row_block, row, col_block, col, channel], then move
    # both block dimensions ahead of the within-block dimensions.
    blocked_image = tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1])
    imported_shape = mtf.Shape([
        batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
        one_channel_dim
    ])
    x = mtf.import_tf_tensor(mesh, blocked_image, imported_shape)
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
        one_channel_dim
    ])

    # add some convolutional layers to demonstrate that convolution works.
    fh_dim = mtf.Dimension("fh", 9)
    fw_dim = mtf.Dimension("fw", 9)
    filters1_dim = mtf.Dimension("filters1", 16)
    filters2_dim = mtf.Dimension("filters2", 16)
    kernel1 = mtf.get_variable(mesh, "kernel1",
                               [fh_dim, fw_dim, one_channel_dim, filters1_dim])
    kernel2 = mtf.get_variable(mesh, "kernel2",
                               [fh_dim, fw_dim, filters1_dim, filters2_dim])

    def conv_relu(inputs, kernel):
        # Stride-1, SAME-padded convolution across the block grid, then ReLU.
        convolved = mtf.conv2d_with_blocks(inputs,
                                           kernel,
                                           strides=[1, 1, 1, 1],
                                           padding="SAME",
                                           h_blocks_dim=row_blocks_dim,
                                           w_blocks_dim=col_blocks_dim)
        return mtf.relu(convolved)

    f1 = conv_relu(x, kernel1)
    f2 = conv_relu(f1, kernel2)
    x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

    # add some fully-connected dense layers.
    hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
    hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

    # The first dense layer reduces over the last four dimensions of x.
    h1 = mtf.layers.dense(x,
                          hidden_dim1,
                          reduced_dims=x.shape.dims[-4:],
                          activation=mtf.relu,
                          name="hidden1")
    h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2")
    logits = mtf.layers.dense(h2, classes_dim, name="logits")

    loss = None
    if labels is not None:
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [FLAGS.batch_size]),
                                      mtf.Shape([batch_dim]))
        per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(per_example_loss)
    return logits, loss
Ejemplo n.º 28
0
 def talking_heads(self,
                   context,
                   inp,
                   name,
                   input_heads_dims,
                   output_heads_dims,
                   dynamic_projections_from=None):
     """Project ``inp`` from ``input_heads_dims`` to ``output_heads_dims``.

     Dimensions present in both lists are shared, dimensions only in the
     input list are reduced away, and dimensions only in the output list are
     created.

     Args:
       context: supplies ``variable_dtype``, ``model.ensemble_dims`` and
         ``model.model_dim``.
       inp: the tensor to project.
       name: layer name; used as the variable scope in the dynamic case.
       input_heads_dims: heads dimensions of ``inp``.
       output_heads_dims: heads dimensions wanted in the result.
       dynamic_projections_from: optional sequence of tensors from which
         input-dependent projection matrices are computed.

     Returns:
       ``inp`` unchanged when both dimension lists hold the same dimensions,
       otherwise the projected tensor.
     """
     shared_dims = [d for d in input_heads_dims if d in output_heads_dims]
     reduced_dims = [
         d for d in input_heads_dims if d not in output_heads_dims
     ]
     new_dims = [d for d in output_heads_dims if d not in input_heads_dims]

     if not reduced_dims and not new_dims:
         # Output dimensions are same as input dimensions.  Return the input
         return inp

     if not dynamic_projections_from:
         # No dynamic projections.  Static talking-heads projection only
         return mtf.layers.dense(inp,
                                 reduced_dims=reduced_dims,
                                 new_dims=new_dims,
                                 use_bias=False,
                                 activation=None,
                                 variable_dtype=context.variable_dtype,
                                 name=name,
                                 expert_dims=context.model.ensemble_dims +
                                 shared_dims)

     # There are one or more dynamic talking-heads-projections
     with tf.variable_scope(name):
         # The static projection is the same one the static-only path builds.
         # Its weight matrix is created with get_variable rather than
         # mtf.layers.dense() so it can be folded into a dynamic projection.
         static_p_initializer = mtf.layers.VarianceScalingInitializer()(
             reduced_dims, new_dims)
         static_p_shape = (context.model.ensemble_dims + shared_dims +
                           reduced_dims + new_dims)
         static_p = mtf.get_variable(inp.mesh,
                                     "kernel",
                                     static_p_shape,
                                     initializer=static_p_initializer,
                                     dtype=context.variable_dtype)

         def dynamic_projection(index, source):
             # One input-dependent projection matrix computed from `source`.
             kernel_initializer = mtf.layers.VarianceScalingInitializer(
                 self.dynamic_projections_init_scale /
                 mtf.Shape(reduced_dims).size)
             return mtf.layers.dense(
                 source,
                 reduced_dims=[context.model.model_dim],
                 new_dims=shared_dims + reduced_dims + new_dims,
                 use_bias=False,
                 activation=None,
                 variable_dtype=context.variable_dtype,
                 name="%s_dynamic_%d" % (name, index),
                 expert_dims=context.model.ensemble_dims,
                 kernel_initializer=kernel_initializer)

         ps = [
             dynamic_projection(i, dp_from)
             for i, dp_from in enumerate(dynamic_projections_from)
         ]

         # Fold the static projection into one of the dynamic projections.
         # Mathematically, we could add all the dynamic projections together
         #   here, but it would create a very large tensor which contained
         #   both the query-length and memory-length dimensions, and would
         #   probably be slower in practice.
         ps[0] += static_p
         projected = [
             mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps
         ]
         return mtf.add_n(projected)
Ejemplo n.º 29
0
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    """Apply one multi-head attention layer to ``x``.

    Args:
      x: input tensor, laid out as [batch, seq, n_embd].
      scope: variable scope name for every variable created here.
      n_state: dimension whose size must be divisible by ``params["n_head"]``.
      attention_type: one of "local", "global" or "linear".
      params: configuration mapping.  Keys read here: "n_head", "n_embd",
        "causal", "mode", "attn_dropout", "res_dropout" and, optionally,
        "num_mem_kv" and "local_attention_radius".
      bias: attention mask bias for "global" attention, or a value for which
        ``exists(bias)`` is false to attend without a bias.
      dim_seq: the sequence dimension, used for one-hot selection and
        gathering during incremental inference.
      memory_length_dim: dimension that replaces dimension 1 of the keys and
        values for "global" attention.
      variable_dtype: supplies master/slice/activation dtypes for variables.
      context: optional decoding context; when set, the keys and values are
        recorded on it and incremental inference may be enabled.

    Returns:
      The attention output after the output projection, the output bias and,
      in train mode with ``res_dropout > 0``, residual dropout.

    Raises:
      NotImplementedError: if ``attention_type`` is not recognised.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Write the new key/value into the cached states at the current
            # position and keep every other position unchanged.
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # TODO: pass in fake context
                # Default to no bias so the attention call below never reads
                # an unassigned name when the caller supplies no mask.
                broadcasted_bias = None
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                # Attention dropout is applied in train mode only.
                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            # Zero-initialised bias added after the output projection.
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
Ejemplo n.º 30
0
  def mtf_model_fn(self, features, mesh):
    """Build the image-transformer decoder graph and its loss.

    Args:
      features: mapping with a "targets" entry and, for conditional models,
        an "inputs" entry.  A shallow copy is taken first.
      mesh: the mesh on which tensors and variables are created.

    Returns:
      A `(logits, loss)` pair.  `logits` is reshaped to
      [batch, rows, orig_cols, channels, outputs_vocab]; `loss` is the mean
      softmax cross-entropy against the unshifted targets.

    Raises:
      ValueError: if `hparams.attention_type` is not one of
        "local1d_spatial", "local2d_spatial" or "local1d".
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    act_dtype = self.activation_type

    # Flatten each image into one sequence and build the right-shifted copy
    # that is fed to the decoder.  Target vocab size is assumed fixed.
    seq_len = hparams.img_len * hparams.img_len * hparams.num_channels
    flat_targets = tf.reshape(
        tf.to_int32(features["targets"]), [hparams.batch_size, seq_len])
    flat_shifted = common_layers.shift_right_2d(flat_targets)

    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def to_batch_by_length(tensor, name):
      shape = mtf.Shape([batch_dim, self.length_dim])
      return mtf.import_tf_tensor(mesh, tensor, shape, name=name)

    targets = to_batch_by_length(flat_targets, "targets")
    shifted_targets = to_batch_by_length(flat_shifted, "shifted_targets")

    # Content embedding: look up the shifted targets in a learned table.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=act_dtype)
    x = mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim)

    # Position embedding, flattened to [length, model].
    positional = mtf.reshape(
        self.create_positional_emb_2d(targets),
        [self.length_dim, self.model_dim])
    x += positional

    # Conditional case: add the embedded inputs to the target embedding.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
      raw_inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = to_batch_by_length(raw_inputs, "inputs")
      inputs_embedding_var = mtf.layers.embedding(
          mesh, "input_embedding",
          mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
          activation_dtype=act_dtype)
      x += mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim)

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    attention_type = hparams.attention_type
    if attention_type == "local1d_spatial":
      decoder_fn = local_attention1d_spatial_decoder
    elif attention_type == "local2d_spatial":
      decoder_fn = local_attention2d_spatial_decoder
    elif attention_type == "local1d":
      decoder_fn = local_attention1d_masked_decoder
    else:
      raise ValueError("Invalid attention type.")
    decoder_output = decoder_fn(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)

    # Project to the output vocabulary and lay logits out as
    # [batch, length, vocab] for the loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    flat_logits_shape = mtf.Shape(
        [batch_dim, self.length_dim, self.outputs_vocab_dim])
    logits = mtf.reshape(logits, flat_logits_shape)

    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=act_dtype)
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.outputs_vocab_dim))

    # Hook for auxiliary losses; nothing is appended to this list here.
    extra_losses = []
    for extra in extra_losses:
      loss += extra

    # Restore the original image layout of the targets.
    image_logits_shape = mtf.Shape([
        batch_dim, self.rows_dim, self.orig_cols_dim,
        self.channels_dim, self.outputs_vocab_dim])
    logits = mtf.reshape(logits, image_logits_shape)

    return logits, loss
Ejemplo n.º 31
0
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow.

    Embeds the input tokens and positions, runs ``params["n_layer"]`` blocks,
    projects to vocabulary logits and, in "train"/"eval" mode, computes the
    loss.

    Args:
      mtf_features: parsed by ``parse_inputs``; "labels" is read from it in
        "train"/"eval" mode.
      other_features: parsed by ``parse_inputs``; "attn_bias" and
        "memory_length_dim" are read from it for every block.
      params: configuration mapping.  Keys indexed directly here (so they
        must be present): "axial_pos_emb", "embed_dropout", "mode",
        "n_layer", "share_parameters", "recompute_grad", "no_weight_tie",
        "causal", "num_microbatches".  Optional keys: "z_loss",
        "entmax_loss", "padding_id".
      mesh: the mesh on which variables are created.
      variable_dtype: supplies master/slice/activation dtypes.
      context: optional decoding context; enables incremental inference when
        ``is_incremental_inference(context)`` is true.

    Returns:
      A ``(logits, loss, loss_batch)`` tuple.  ``loss`` and ``loss_batch`` are
      None unless ``params["mode"]`` is "train" or "eval".
    """

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        # Only the token at the current position is fed through the network.
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        # Full range of positions normally; the single current position when
        # decoding incrementally.
        position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            # NOTE(review): this reuses the name "wte_dropout" from the token
            # embedding dropout above - confirm whether that is intentional.
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        # With shared parameters every block uses the same (empty) scope name;
        # otherwise each layer gets its own "h<layer>" scope.
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        # Untied output projection: a separate linear layer maps to the
        # vocabulary.  No final layer norm is applied on this path.
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        # During incremental inference the output has a length-1 sequence dim.
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            # Tied weights: the token embedding matrix doubles as the output
            # projection.
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        # Scale by the microbatch count.
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
Ejemplo n.º 32
0
def recon_model(mesh,
                data,
                bparams,
                ipkerror,
                R0,
                x0,
                nc=FLAGS.nc,
                bs=FLAGS.box_size,
                batch_size=FLAGS.batch_size,
                a0=FLAGS.a0,
                a=FLAGS.af,
                nsteps=FLAGS.nsteps,
                dtype=tf.float32):
    """Build the Mesh TensorFlow graph of the biased-field reconstruction loss.

    A trainable field named 'linear' is evolved with `mtfpm.lpt_init_single`
    followed by `mtfpm.nbody_single`.  Three fields derived from the initial
    field (mean-subtracted field, mean-subtracted square, mean-subtracted
    shear) are read out and painted back with the final particle state, then
    combined linearly with the coefficients in `bparams` to form the model.
    The loss is the sum of
      * `chisq`: the sum of squares of (model - data) after it has been
        divided in Fourier space by the square root of the interpolated error
        power and multiplied by a Gaussian weight controlled by `R0`, and
      * `prior`: the sum over modes of |field_k|^2 / P(k), times bs**3, with
        P(k) interpolated from the tabulated Planck15 power spectrum.

    Args:
      mesh: mtf mesh on which every tensor and variable is created.
      data: observed field; converted with `tf.convert_to_tensor` and imported
        with shape [batch, nx, ny, nz].
      bparams: sequence of three coefficients (b1, b2, bs2).
      ipkerror: indexable whose items 0 and 1 are numpy arrays holding the k
        values and the error power used to weight the residual.
      R0: smoothing scale entering the Gaussian weight
        exp(-k^2 * (R0 * bs / nc)^2).
      x0: initial value of the 'linear' variable, or None to draw it from a
        unit normal distribution.
      nc: grid size along each spatial dimension.
      bs: box size.
      batch_size: size of the batch dimension.
      a0: first scale factor of the integration stages.
      a: last scale factor of the integration stages.
      nsteps: number of integration stages between `a0` and `a` (inclusive).
      dtype: tf.float32 or tf.float64.

    Returns:
      A tuple (fields, metrics, kv) where fields is
      [fieldvar, final_field, model], metrics is [chisq, prior, loss] and kv
      is the list of Fourier-kernel tensors [ky, kz, kx].

    Raises:
      ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """

    b1, b2, bs2 = bparams
    kerror, perror = ipkerror[0].astype(np.float32), ipkerror[1].astype(
        np.float32)

    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Without this branch npdtype/cdtype would be unbound below.
        raise ValueError(
            "Unsupported dtype %s: expected tf.float32 or tf.float64" %
            (dtype,))
    print("Dtype : ", dtype, npdtype)

    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo may not reach half of the smallest block extent.
    max_halo = 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)
    if halo_size >= max_halo:
        new_size = int(max_halo)
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    batch_dim = mtf.Dimension("batch", batch_size)

    # Tabulated linear power spectrum: only the power column is needed, and
    # the file is parsed a single time.
    plin = np.loadtxt('..//data/Planck15_a1p00.txt').T[1]
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])
    pke_dim = mtf.Dimension("epk", len(perror))
    pkerror = mtf.import_tf_tensor(mesh,
                                   perror.astype(npdtype),
                                   shape=[pke_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    # NOTE(review): the [ky, kz, kx] order presumably mirrors the axis order
    # of the distributed FFT output -- confirm against mesh_utils.r2c3d.
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation
    if x0 is None:
        field_initializer = tf.random_normal_initializer(mean=0.0,
                                                         stddev=1,
                                                         seed=None)
    else:
        field_initializer = tf.constant_initializer(x0)
    fieldvar = mtf.get_variable(mesh,
                                'linear',
                                part_shape,
                                initializer=field_initializer)

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr,
                                     halo_size)

    # paint the field
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                     hr_shape, halo_size, splittables, mesh)

    # Get the fields for bias
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables,
                              mesh)
    mstate = mpm.mtf_indices(hr_field.mesh,
                             shape=part_shape[1:],
                             dtype=tf.float32)
    # Replicate the grid positions along the batch dimension.
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])
    # Fourier dimensions reordered from kv's [ky, kz, kx] to [kx, ky, kz];
    # shared by every r2c3d call below.
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    tfnc, tfbs = cswisef.float_to_mtf(nc * 1., mesh,
                                      scalar), cswisef.float_to_mtf(
                                          bs, mesh, scalar)

    # Mean-subtracted initial field.
    initc = fieldvar
    d0 = initc - mtf.reduce_mean(initc)
    # Mean-subtracted squared field.
    d2 = initc * initc
    d2 = d2 - mtf.reduce_mean(d2)
    # Mean-subtracted shear field.
    cfield = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
    shearfield = mtf.zeros(mesh, shape=part_shape)
    shearfield = shear(shearfield, cfield, kv, tfnc, tfbs)
    s2 = shearfield - mtf.reduce_mean(shearfield)

    def _readout(field):
        """Read `field` out at the initial grid positions X."""
        return mcomp.cic_readout_fr(field, [X],
                                    hr_shape=hr_shape,
                                    halo_size=halo_size,
                                    splittables=splittables,
                                    mesh=mesh)

    def _paint_weighted(canvas, weights):
        """Paint the final particle state onto `canvas` with `weights`."""
        return mcomp.cic_paint_fr(canvas,
                                  final_state,
                                  output_shape=part_shape,
                                  hr_shape=hr_shape,
                                  halo_size=halo_size,
                                  splittables=splittables,
                                  mesh=mesh,
                                  weights=weights)

    dread = _readout(d0)
    d2read = _readout(d2)
    s2read = _readout(s2)

    ed, ed2, es2 = mtf.zeros(mesh, shape=part_shape), mtf.zeros(
        mesh, shape=part_shape), mtf.zeros(mesh, shape=part_shape)
    ed = _paint_weighted(ed, dread)
    ed2 = _paint_weighted(ed2, d2read)
    es2 = _paint_weighted(es2, s2read)

    model = ed * b1 + ed2 * b2 + es2 * bs2
    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)
    diff = model - mtfdata

    def _k_squared(kx, ky, kz):
        """Return the squared wavenumber on the broadcast 3-D grid."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        return (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Get prior
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by sqrt(P(k)) interpolated on a log-spaced grid."""
        kk = tf.sqrt(_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  #* nc**3

    # Total loss
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

    def _cwise_diff(kfield, pk, kx, ky, kz):
        """Divide kfield by sqrt of the error power interpolated at |k|."""
        kk = tf.sqrt(_k_squared(kx, ky, kz))
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(x=kk,
                                                 x_ref_min=kerror.min(),
                                                 x_ref_max=kerror.max(),
                                                 y_ref=pk)
        priormesh = tf.reshape(pkmesh, kshape)
        priormesh = tf.cast(priormesh**0.5, kfield.dtype)
        return kfield / priormesh

    cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv, output_dtype=cdtype)

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply kfield by the Gaussian weight exp(-k^2 (R0 bs / nc)^2)."""
        kk = _k_squared(kx, ky, kz)
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
Ejemplo n.º 33
0
  def _layer_stack(self,
                   x,
                   layers,
                   encoder_output=None,
                   self_attention_mask=None,
                   encdec_attention_mask=None,
                   losses=None,
                   step_num=None,
                   encdec_tensors=None,
                   states=None):
    """Apply a stack of attention / feed-forward layers to x.

    Each layer is applied in residual form: its input is first rescaled by
    the inverse root-mean-square over model_dim (times a learned scale), and
    the layer output is added back onto the un-normalized input.  A final
    normalization is applied after the last layer.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings naming the layer types; "att", "enc_att",
        "local_att" and "compressed_att" select attention variants, any other
        value is passed to self._feedforward_layer.
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended-to
      step_num: an optional mtf integer Scalar; when given, the stack runs in
        incremental mode.
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v), (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode); two
        entries are consumed, in order, by every "att" / "local_att" layer.
    Returns:
      In non-incremental mode, a mtf.Tensor with shape
      [<batch_dims>, length_dim, model_dim].  In incremental mode, a tuple
      (x, new_states) where new_states lists the updated key/value tensors.
    Raises:
      ValueError: if a "compressed_att" layer is used in incremental mode.
    """
    hparams = self._hparams
    incremental = step_num is not None
    is_training = (
        getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
        == tf.estimator.ModeKeys.TRAIN)

    def dropout(y):
      # Incremental mode skips dropout entirely.
      if incremental:
        return y
      return mtf.dropout(
          y, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    # One scale vector per layer plus one for the final normalization, stored
    # together in a single variable and handed out in call order.
    norm_count_dim = mtf.Dimension("layer_norms", len(layers) + 1)
    combined_scales = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([norm_count_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    pending_scales = mtf.unstack(combined_scales, norm_count_dim)

    def rms_norm(y):
      scale = pending_scales.pop(0)
      mean_square = mtf.reduce_mean(mtf.square(y), reduced_dim=self.model_dim)
      return y * mtf.rsqrt(mean_square + hparams.norm_epsilon) * scale

    if incremental:
      states = list(states)
      new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
      scope_name = "%s_%d" % (layer_type, lnum)
      with tf.variable_scope(scope_name):
        if layer_type == "att":
          # Self attention layer
          if incremental:
            normed = rms_norm(x)
            prev_k = states.pop(0)
            prev_v = states.pop(0)
            y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                normed,
                prev_k=prev_k,
                prev_v=prev_v,
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            new_states.extend([new_k, new_v])
            x += y
          else:
            attended = mtf.layers.multihead_attention(
                rms_norm(x), None,
                self_attention_mask, self.kv_dim, self.heads_dim,
                is_training,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="att")
            x += dropout(attended)
        elif layer_type == "enc_att":
          # Encoder-Decoder attention layer
          if incremental:
            q_var, o_var, k, v = encdec_tensors[lnum]
            x += mtf.layers.multihead_encdec_attention_incremental(
                rms_norm(x),
                q_var, o_var, k, v,
                encdec_attention_mask,
                name="enc_att")
          else:
            attended = mtf.layers.multihead_attention(
                rms_norm(x), encoder_output,
                encdec_attention_mask, self.kv_dim, self.heads_dim,
                is_training,
                dropout=hparams.attention_dropout,
                dropout_broadcast_dims=[self.length_dim],
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="enc_att")
            x += dropout(attended)
        elif layer_type == "local_att":
          if incremental:
            normed = rms_norm(x)
            prev_k = states.pop(0)
            prev_v = states.pop(0)
            y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                normed,
                prev_k=prev_k,
                prev_v=prev_v,
                step_num=step_num,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                name="local_att")
            new_states.extend([new_k, new_v])
            x += y
          else:
            attended = mtf.layers.masked_local_attention_1d(
                rms_norm(x),
                self.kv_dim, self.heads_dim, is_training,
                window_size=hparams.local_attention_window_size,
                master_dtype=self.master_dtype,
                slice_dtype=self.slice_dtype,
                length_per_split=mtf.tensor_dim_to_size_per_split(
                    hparams.layout, hparams.mesh_shape,
                    self.max_length_dim),
                name="local_att")
            x += dropout(attended)
        elif layer_type == "compressed_att":
          if incremental:
            raise ValueError("compressed_att incremental not implemented")
          attended = mtf.layers.multihead_self_attention_memory_compressed(
              rms_norm(x),
              mask_right=True,
              compression_factor=hparams.compression_factor,
              kv_channels=self.kv_dim,
              heads=self.heads_dim,
              is_training=is_training,
              dropout=hparams.attention_dropout,
              dropout_broadcast_dims=[self.length_dim],
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="compressed_att")
          x += dropout(attended)
        else:
          # Anything else is a feed-forward layer.
          if incremental:
            # Temporarily insert a length dimension of size 1 in front of
            # the last dimension.
            squeezed_shape = x.shape
            expanded_shape = mtf.Shape(
                squeezed_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                + squeezed_shape.dims[-1:])
            x = mtf.reshape(x, expanded_shape)
          ffn_out = self._feedforward_layer(
              rms_norm(x), layer_type, losses=losses)
          x += dropout(ffn_out)
          if incremental:
            # Drop the temporary length dimension again.
            x = mtf.reshape(x, squeezed_shape)

    x = dropout(rms_norm(x))
    # Every scale vector must have been consumed exactly once.
    assert not pending_scales
    if incremental:
      return x, new_states
    return x
Ejemplo n.º 34
0
  def __init__(self,
               mesh,
               query_input_dim,
               memory_input_dim,
               output_dim,
               key_dim,
               value_dim,
               query_heads_dims,
               memory_heads_dims,
               variable_dtype,
               shared_kv=False,
               no_query=False,
               combine_dims=True,
               ensemble_dim=None,
               keep_query_heads_dims=False,
               fold_scaling_into_initializer=True,
               context=None,
               experts_hparams=None,
               expert_computation="qkv"):
    """Create attention parameters whose projections are computed by experts.

    The parent constructor is invoked with make_attention_vars=False; the
    projections not handled by the mixture-of-experts layer are created here
    as ordinary variables, and the MoE layer itself is stored in
    self.moe_layer.

    Args:
      mesh: forwarded unchanged to the parent constructor.
      query_input_dim: forwarded unchanged to the parent constructor.
      memory_input_dim: forwarded unchanged to the parent constructor.
      output_dim: forwarded unchanged to the parent constructor.
      key_dim: forwarded unchanged to the parent constructor.
      value_dim: forwarded unchanged to the parent constructor.
      query_heads_dims: forwarded unchanged to the parent constructor.
      memory_heads_dims: forwarded unchanged to the parent constructor.
      variable_dtype: forwarded unchanged to the parent constructor.
      shared_kv: forwarded to the parent constructor; must end up True.
      no_query: forwarded unchanged to the parent constructor.
      combine_dims: forwarded to the parent constructor; must end up True.
      ensemble_dim: forwarded unchanged to the parent constructor.
      keep_query_heads_dims: forwarded unchanged to the parent constructor.
      fold_scaling_into_initializer: forwarded unchanged to the parent
        constructor.
      context: stored on the instance as self.context.
      experts_hparams: required object exposing the MoE settings read below
        (moe_gating, num_experts, loss_coef, group_size, hidden_size, ...).
        The None default exists only for signature compatibility.
      expert_computation: "qkv" to compute both q and kv with the experts,
        "q" to compute only q (a regular "kv" variable is created), or "kv"
        to compute only kv (a regular "q" variable is created).

    Raises:
      ValueError: if experts_hparams is None, if expert_computation is not one
        of "qkv", "q", "kv", or if combine_dims / shared_kv is not True.
      NotImplementedError: if the unit scaling convention is enabled.
    """
    # Fail fast with a clear message instead of an AttributeError on None
    # after variables have already been created.
    if experts_hparams is None:
      raise ValueError("experts_hparams must be provided for ExpertsAttention.")

    super(ExpertsAttentionParams, self).__init__(
        mesh=mesh,
        query_input_dim=query_input_dim,
        memory_input_dim=memory_input_dim,
        output_dim=output_dim,
        key_dim=key_dim,
        value_dim=value_dim,
        query_heads_dims=query_heads_dims,
        memory_heads_dims=memory_heads_dims,
        variable_dtype=variable_dtype,
        shared_kv=shared_kv,
        no_query=no_query,
        combine_dims=combine_dims,
        ensemble_dim=ensemble_dim,
        keep_query_heads_dims=keep_query_heads_dims,
        fold_scaling_into_initializer=fold_scaling_into_initializer,
        make_attention_vars=False)

    self.context = context
    self.expert_computation = expert_computation

    # Unless we want to compute both q and kv, we can use the normal MoE
    # settings.
    if expert_computation == "qkv":
      experts_attention_compute_qkv = True
    elif expert_computation in ["q", "kv"]:
      experts_attention_compute_qkv = False
      if expert_computation == "q":
        # Experts compute q only, so kv needs its own variable.
        # Always assume shared_kv.
        self.wkv = mtf.get_variable(
            self.mesh,
            "kv",
            self.k_shape,
            initializer=tf.random_normal_initializer(
                stddev=self.memory_input_dim.size ** -0.5),
            dtype=self.variable_dtype)
      else:  # Computing kv with experts, so q needs its own variable.
        self.wq = mtf.get_variable(
            self.mesh,
            "q",
            self.q_shape,
            initializer=tf.random_normal_initializer(
                stddev=self.query_input_dim.size ** -0.5),
            dtype=self.variable_dtype)
    else:
      raise ValueError("Invalid expert computation mode: {}".format(
          expert_computation))

    # ExpertsAttention, for simplicity, asserts that combine_dims is True, and
    # for efficiency, that shared_kv is True.
    if not self.combine_dims:
      raise ValueError("combine_dims must be True for ExpertsAttention.")
    if not self.shared_kv:
      raise ValueError("shared_kv must be True for ExpertsAttention.")
    if mtf.layers.unit_scaling_convention():
      raise NotImplementedError

    # Now replace "heads" dim with the "d_model" name to avoid conflicts when
    # we want to partition both "experts_hidden" and "heads".
    moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)

    tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
        experts_hparams.hidden_size))
    tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
    self.moe_layer = mtf.transformer.moe.MoE1D(
        moe_gating=experts_hparams.moe_gating,
        num_experts=experts_hparams.num_experts,
        loss_coef=experts_hparams.loss_coef,
        group_size=experts_hparams.group_size,
        min_expert_capacity=experts_hparams.min_expert_capacity,
        capacity_factor_train=experts_hparams.capacity_factor_train,
        capacity_factor_eval=experts_hparams.capacity_factor_eval,
        switch_policy_train=experts_hparams.switch_policy_train,
        switch_policy_eval=experts_hparams.switch_policy_eval,
        switch_dropout=experts_hparams.switch_dropout,
        switch_temperature=experts_hparams.switch_temperature,
        switch_jitter=experts_hparams.switch_jitter,
        ntlb_top_k=experts_hparams.ntlb_top_k,
        hidden_size=experts_hparams.hidden_size,
        output_dim=moe_output_dims,
        use_experts_attention=experts_attention_compute_qkv,
        activation=experts_hparams.activation,
        z_loss=experts_hparams.z_loss)
Ejemplo n.º 35
0
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
  """Attention that blends softmax weights with doubly-normalized weights.

  Two sets of attention weights are derived from the same q.k logits:
    * the usual softmax over memory_length_dim, and
    * a "doubly" normalized set: a log-softmax over the query axis followed
      by a softmax over memory_length_dim.
  They are mixed with a learned scalar "doubly_coeff" (initialized to 0.5 and
  clipped to [0, 1]) before dropout and the weighted sum over v.

  The query axis is identified as a dimension named "length" whose size
  equals memory_length_dim.size.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    context: context of the attention layer; its mesh, variable_dtype and
      train attributes are read.
    memory_length_dim: a Dimension indexing the key/value pairs
    key_dim: a Dimension, the channels of queries and keys
    value_dim: a Dimension, the channels of values
    bias: an optional Tensor added to the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension removed from
      the dropout noise shape.
    extra_logit: an optional scalar or tensor passed to every softmax.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  scores = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    scores += bias

  query_length_dim = mtf.Dimension("length", memory_length_dim.size)
  mix = mtf.get_variable(
      context.mesh, "doubly_coeff", [],
      initializer=tf.constant_initializer(0.5),
      dtype=context.variable_dtype)
  # Keep the mixing coefficient inside [0, 1].
  mix = mtf.maximum(mtf.minimum(mix, 1.), 0.)

  # Standard attention: normalize over the memory axis.
  softmax_weights = mtf.softmax(
      scores, memory_length_dim, extra_logit=extra_logit)

  # Doubly-normalized attention: query axis first (in log space), then the
  # memory axis.
  query_log_weights = mtf.log_softmax(
      scores, query_length_dim, extra_logit=extra_logit)
  doubly_weights = mtf.softmax(
      query_log_weights, memory_length_dim, extra_logit=extra_logit)

  doubly_part = mix * doubly_weights
  softmax_part = (1. - mix) * softmax_weights
  blended = doubly_part + softmax_part

  keep_prob = 1.0 - dropout_rate
  noise_shape = blended.shape - dropout_broadcast_dims
  blended = mtf.dropout(
      blended, context.train, keep_prob, noise_shape=noise_shape)

  result_shape = q.shape - key_dim + value_dim
  return mtf.einsum([blended, v], result_shape)
Ejemplo n.º 36
0
    def mtf_model_fn(self, features, mesh):
        """Build the convolutional classifier graph and its loss on `mesh`.

        The input image is split into row/column blocks, passed through an
        initial 7x7 convolution, batch-norm + relu, `hparams.num_layers`
        groups of three residual block layers, a dense relu layer and a
        final dense layer producing 10 class logits.

        Args:
          features: dict with an "inputs" tensor reshapeable to
            [batch_size, row_blocks, rows_size // row_blocks, col_blocks,
             num_channels * cols_size // col_blocks, 1] and a "targets"
            tensor that squeezes (axes 2 and 3) and reshapes to [batch_size].
          mesh: the mtf mesh on which tensors and variables are created.

        Returns:
          A tuple (logits, loss): logits is a mtf.Tensor with shape
          [batch, one_channel, classes]; loss is the mean softmax
          cross-entropy over the batch.
        """
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

        # Per-block extents of the input.  They are computed once and used
        # both for the tf.reshape and for the mtf dimensions so the two can
        # never disagree (previously the dimensions were hard-coded to 32/96,
        # which only matched 32x32x3 inputs with a single block per axis).
        block_rows = hparams.rows_size // hparams.row_blocks
        block_cols = (
            hparams.num_channels * hparams.cols_size // hparams.col_blocks)

        # Declare all the dimensions
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
        filter_h_dim = mtf.Dimension("filter_height", 7)
        filter_w_dim = mtf.Dimension("filter_width", 7)
        filters = mtf.Dimension("filters", hparams.filter_sizes[0])
        rows_dim = mtf.Dimension("rows_size", block_rows)
        cols_dim = mtf.Dimension("cols_size", block_cols)
        row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
        col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
        classes_dim = mtf.Dimension("classes", 10)
        one_channel_dim = mtf.Dimension("one_channel", 1)

        inputs = features["inputs"]
        x = mtf.import_tf_tensor(
            mesh,
            tf.reshape(inputs, [
                hparams.batch_size, hparams.row_blocks, block_rows,
                hparams.col_blocks, block_cols, 1
            ]),
            mtf.Shape([
                batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim,
                one_channel_dim
            ]))
        # Bring both block dimensions in front of the per-block extents.
        x = mtf.transpose(x, [
            batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim,
            one_channel_dim
        ])

        x = mtf.to_float(x)
        initial_filters = mtf.get_variable(
            mesh, "init_filters",
            mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
        x = mtf.conv2d_with_blocks(x,
                                   initial_filters,
                                   strides=[1, 1, 1, 1],
                                   padding="SAME",
                                   h_blocks_dim=None,
                                   w_blocks_dim=col_blocks_dim)

        x = batch_norm_relu(x, is_training)

        # Residual stages: each iteration applies three block layers, the
        # last two with stride 2.
        for layer in range(hparams.num_layers):
            layer_name = "block_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Residual block layer
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[0],
                                blocks=hparams.layer_sizes[0],
                                strides=[1, 1, 1, 1],
                                is_training=is_training,
                                name="block_layer1",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[1],
                                blocks=hparams.layer_sizes[1],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer2",
                                row_blocks_dim=None,
                                col_blocks_dim=None)
                x = block_layer(inputs=x,
                                filters=hparams.filter_sizes[2],
                                blocks=hparams.layer_sizes[2],
                                strides=[1, 2, 2, 1],
                                is_training=is_training,
                                name="block_layer3",
                                row_blocks_dim=None,
                                col_blocks_dim=None)

        # Calculate the logits and loss.
        out = x
        # Reduce over the last five dimensions of the activations.
        outputs = mtf.layers.dense(out,
                                   hidden_dim,
                                   reduced_dims=out.shape.dims[-5:],
                                   activation=mtf.relu,
                                   name="dense")

        # We assume fixed vocab size for targets
        labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
        labels = mtf.import_tf_tensor(mesh,
                                      tf.reshape(labels, [hparams.batch_size]),
                                      mtf.Shape([batch_dim]))

        logits = mtf.layers.dense(outputs, classes_dim, name="logits")
        soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, classes_dim)

        # Reshape logits so it doesn't break inside t2t.
        logits = mtf.reshape(
            logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
        loss = mtf.reduce_mean(loss)
        return logits, loss