Example #1
def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T]
  # part 1: relative distances >= 0
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)
  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # part 2: relative distances <= 0
  term_bd_right = term_bd[:, :, :, t - 1:]
  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # Lower-triangular mask (ones on and below the diagonal).
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # Stitch the two halves together.
  return tf.where(mask > 0, term_bd_left, term_bd_right)
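A minimal, self-contained illustration (toy size only, not library code) of how the two slices above split the 2T-1 relative distances: after the tf.reverse, the last axis of term_bd runs from T-1 down to -(T-1), so the left slice keeps the non-negative distances, the right slice the non-positive ones, and they overlap only at distance 0.

import tensorflow as tf

# Hypothetical toy size, just to make the slicing in _RelPositionBias concrete.
t = 3
# After the tf.reverse above, the offset axis is ordered
# [T-1, T-2, ..., 0, -1, ..., -(T-1)].
offsets = tf.range(t - 1, -t, -1)     # [2, 1, 0, -1, -2]
left = offsets[:t]                    # non-negative distances: [2, 1, 0]
right = offsets[t - 1:]               # non-positive distances: [0, -1, -2]
print(left.numpy(), right.numpy())    # the halves overlap only at distance 0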
Example #2
  def _GetExpertDist(self, theta, inputs, *args):
    """Computes the per-expert mixing distribution from the input tensors."""
    # TODO(huangyp): support the more general case when batch size is not 1.
    # Input shape can be either [batch, length, dim] or [length, batch, dim].
    reshaped_inputs = tf.reshape(inputs, [-1, self.params.cond_dim])
    if self.params.nonzeros_mean:
      per_example_emb = tf.reduce_sum(reshaped_inputs, 0)
      nonzeros = tf.cast(
          tf.math.count_nonzero(reshaped_inputs, 0), dtype=tf.float32)
      per_example_emb /= (nonzeros + 1e-10)
    else:
      per_example_emb = tf.reduce_mean(reshaped_inputs, 0)
    expert_dist = tf.nn.sigmoid(tf.einsum('i,ij->j', per_example_emb, theta.w))
    return expert_dist
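A quick, self-contained sketch of the nonzeros_mean branch above, using a made-up [length, cond_dim] input rather than real layer inputs: zero entries (including padded all-zero rows) are excluded from the per-dimension mean.

import tensorflow as tf

# Toy [length, cond_dim] input; values are hypothetical, last row is padding.
reshaped_inputs = tf.constant([[1.0, 2.0],
                               [3.0, 0.0],
                               [0.0, 0.0]])
summed = tf.reduce_sum(reshaped_inputs, 0)                      # [4.0, 2.0]
nonzeros = tf.cast(tf.math.count_nonzero(reshaped_inputs, 0),   # [2.0, 1.0]
                   dtype=tf.float32)
per_example_emb = summed / (nonzeros + 1e-10)                   # [2.0, 2.0]
print(per_example_emb.numpy())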
Example #3
def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Attention logits from ...

  Transformer-XL(https://arxiv.org/pdf/1901.02860.pdf, section 3.3) version of
  self attention with relative position embedding.

  Notice padding is supposed to be masked by the caller of this function.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    tensors of the following shapes:
    query:           [B, T, N, H]
    key:             [B, T, N, H]
    abs_pos_emb:     [2T - 1, N, H]. The sinusoid positional embedding from
    https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the emb of relative
    distance i - (T-1).
    content_bias:    [N, H] or None
    positional_bias: [N, H] or None
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
    attention.

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(query)

  key = py_utils.HasShape(key, [b, t, n, h])
  if content_bias is not None:
    content_bias = py_utils.HasShape(content_bias, [n, h])
  else:
    content_bias = 0
  if positional_bias is not None:
    positional_bias = py_utils.HasShape(positional_bias, [n, h])
  else:
    positional_bias = 0

  # [B, N, T, S=T]
  term_ac = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
  term_bd = RelPositionBias(query + positional_bias, abs_pos_emb, is_causal)
  return term_ac + term_bd
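A shape-level sketch of the content term (term_ac) above, with hypothetical toy sizes; the [N, H] biases broadcast over batch and time before the einsum.

import tensorflow as tf

# Hypothetical sizes B=2, T=4, N=3, H=5; tensors are random stand-ins.
b, t, n, h = 2, 4, 3, 5
query = tf.random.normal([b, t, n, h])
key = tf.random.normal([b, t, n, h])
content_bias = tf.random.normal([n, h])         # broadcasts against [B, T, N, H]
term_ac = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
print(term_ac.shape)                             # (2, 3, 4, 4) == [B, N, T, T]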
Example #4
  def FProp(self, theta, inputs, *args):
    p = self.params
    with tf.name_scope(p.name) as scope:
      expert_dist = self._GetExpertDist(theta, inputs, *args)
      if not self.do_eval:
        summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

      # Excludes non-variable extra_theta like global_step.
      var_set = set([key for key, _ in self.body.vars.FlattenItems()])
      values = []
      for key, value in theta.body.FlattenItems():
        if key in var_set and value is not None:
          # Weighted average for all variables created in the body layer.
          value = tf.einsum('i,i...->...', expert_dist, value)
        values.append(value)
      weighted_theta = theta.body.Pack(values)
      return self.body.FProp(weighted_theta, inputs, *args)
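A toy illustration of the 'i,i...->...' mixing einsum above (sizes are hypothetical): expert_dist weights the leading expert axis of each body variable, reducing a per-expert kernel to a single mixed kernel.

import tensorflow as tf

# Hypothetical: 3 experts, a per-expert [3, 4, 4] kernel.
expert_dist = tf.constant([0.2, 0.3, 0.5])
value = tf.random.normal([3, 4, 4])
weighted = tf.einsum('i,i...->...', expert_dist, value)   # -> [4, 4]
print(weighted.shape)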
Example #5
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention."""
  _, t, n, h = py_utils.GetShape(query)

  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Retain only half and change order to [T-1, T-2, ... 0]
  # [T, N, H]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Perform shifting.
  term_bd = tf.reverse(term_bd, [2, 3])
  term_bd = RelShift(term_bd)
  return tf.reverse(term_bd, [2, 3])
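A small index check for the causal slice above (toy size, not library code): after the reverse, the first t entries of the length-(2t-1) embedding correspond exactly to relative distances T-1, ..., 0, i.e. the current and past positions only.

import tensorflow as tf

# Hypothetical toy size. abs_pos_emb[i] encodes relative distance i - (T-1),
# so after tf.reverse the first t entries cover distances T-1, ..., 0 only.
t = 4
distances = tf.range(-(t - 1), t)          # [-3, -2, -1, 0, 1, 2, 3]
kept = tf.reverse(distances, [0])[:t]      # [3, 2, 1, 0]
print(kept.numpy())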
Example #6
  def ProjectInputSequence(self, theta, inputs):
    """Applies input projection for the entire sequence.

    Args:
      theta: A NestedMap of layer weights. Notably, it's expected to contain
        separate weight tensors for the input and hidden state projections,
        for performance reasons, under the keys 'wm_i' (input) and 'wm_h'
        (hidden state).
      inputs: A NestedMap with the following fields:
        - act: A list of Tensors of shape [seqlen, batch, input_dim].

    Returns:
      A Tensor of shape [seqlen, batch, 4 * hidden_dim].
    """
    assert isinstance(inputs.act, list)
    if len(inputs.act) > 1:
      x = tf.concat(inputs.act, -1)
    else:
      x = inputs.act[0]
    # [T, B, 4 * H]
    proj_inputs = tf.einsum('TBD,DH->TBH', x, theta.wm_i)
    return proj_inputs
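A shape-level sketch of the projection einsum above, with hypothetical sizes (seqlen=5, batch=2, input_dim=8, hidden_dim=16, so the weight is [D, 4*H]).

import tensorflow as tf

# Hypothetical sizes; wm_i here is a random stand-in for theta.wm_i.
x = tf.random.normal([5, 2, 8])           # [T, B, D]
wm_i = tf.random.normal([8, 4 * 16])      # [D, 4 * H]
proj_inputs = tf.einsum('TBD,DH->TBH', x, wm_i)
print(proj_inputs.shape)                   # (5, 2, 64) == [T, B, 4 * H]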
  def EinsumBxycBzxBzyc(self, a, b, name=None):
    """Contracts the 'x' axis: [b, x, y, c] x [b, z, x] -> [b, z, y, c]."""
    return tf.einsum('bxyc,bzx->bzyc', a, b, name=name)

  def EinsumBxyBxBxy(self, a, b, name=None):
    """Scales each [y] row of a by the matching entry of b: -> [b, x, y]."""
    return tf.einsum('bxy,bx->bxy', a, b, name=name)

  def EinsumBxycBxBxyc(self, a, b, name=None):
    """Scales each [y, c] slice of a by the matching entry of b."""
    return tf.einsum('bxyc,bx->bxyc', a, b, name=name)

  def EinsumBmtBmBt(self, a, b, name=None):
    """Contracts the 'm' axis: [b, m, t] x [b, m] -> [b, t]."""
    return tf.einsum('bmt,bm->bt', a, b, name=name)

  def EinsumBBmBm(self, a, b, name=None):
    """Scales each [m] row of b by the per-batch scalar a: -> [b, m]."""
    return tf.einsum('b,bm->bm', a, b, name=name)
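A toy shape check for the first helper above, EinsumBxycBzxBzyc, with made-up sizes; it contracts the shared 'x' axis and leaves [b, z, y, c].

import tensorflow as tf

# Hypothetical sizes: b=2, x=3, y=4, c=5, z=6.
a = tf.random.normal([2, 3, 4, 5])         # [b, x, y, c]
w = tf.random.normal([2, 6, 3])            # [b, z, x]
out = tf.einsum('bxyc,bzx->bzyc', a, w)
print(out.shape)                            # (2, 6, 4, 5) == [b, z, y, c]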