def add_pos_signals(x, hparams, name="pos_emb"):
    """Apply the positional-signal helper selected by ``hparams.pos`` to ``x``.

    Args:
      x: input passed through to the ``common_attention`` helper.
      hparams: object exposing ``pos`` ("timing" or "emb"); ``max_length`` is
        also read when ``pos == "emb"``.
      name: name of the enclosing ``tf.variable_scope``; also forwarded as the
        ``name`` argument of ``add_positional_embedding_nd``.

    Returns:
      The value returned by the selected ``common_attention`` helper.

    Raises:
      ValueError: if ``hparams.pos`` is neither "timing" nor "emb".
    """
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O``, which would silently treat any unknown value as "emb".
    if hparams.pos not in ("timing", "emb"):
        raise ValueError(
            "Unknown hparams.pos %r; expected 'timing' or 'emb'." % (hparams.pos,))
    with tf.variable_scope(name, reuse=False):
        if hparams.pos == "timing":
            x = common_attention.add_timing_signal_nd(x)
        else:
            x = common_attention.add_positional_embedding_nd(
                x, hparams.max_length, name=name)
    return x
def add_pos_signals(x, hparams, name="pos_emb"):
  """Apply the positional-signal helper selected by ``hparams.pos`` to ``x``.

  Args:
    x: input passed through to the ``common_attention`` helper.
    hparams: object exposing ``pos`` ("timing" or "emb"); ``max_length`` is
      also read when ``pos == "emb"``.
    name: name of the enclosing ``tf.variable_scope``; also forwarded as the
      ``name`` argument of ``add_positional_embedding_nd``.

  Returns:
    The value returned by the selected ``common_attention`` helper.

  Raises:
    ValueError: if ``hparams.pos`` is neither "timing" nor "emb".
  """
  # Validate up front with a real exception: an ``assert`` is stripped under
  # ``python -O``, which would silently treat any unknown value as "emb".
  if hparams.pos not in ("timing", "emb"):
    raise ValueError(
        "Unknown hparams.pos %r; expected 'timing' or 'emb'." % (hparams.pos,))
  with tf.variable_scope(name, reuse=False):
    if hparams.pos == "timing":
      x = common_attention.add_timing_signal_nd(x)
    else:
      # Pass ``name`` by keyword, as the other call sites of this helper do.
      x = common_attention.add_positional_embedding_nd(
          x, hparams.max_length, name=name)
  return x
Beispiel #3
0
 def testAddPositionalEmbeddingNd(self, input_shape):
   """add_positional_embedding_nd must preserve the shape of its input.

   Args:
     input_shape: shape of the random input; presumably supplied by a
       parameterization decorator outside this view — confirm.
   """
   values = np.random.rand(*input_shape)
   inputs = tf.constant(values, dtype=tf.float32)
   embedded = common_attention.add_positional_embedding_nd(
       inputs, max_length=5, name="pos_embedding")
   # Initialize any variables created above before evaluating the output.
   init_op = tf.global_variables_initializer()
   self.evaluate(init_op)
   output = self.evaluate(embedded)
   self.assertEqual(output.shape, input_shape)