Example #1
0
def BERTRegressionHead():
  """Build a one-output regression head on top of a two-output BERT body.

  The head consumes the top two stack elements and keeps only the first.
  That element is presumably the pooled encoding; confirm against the BERT
  body this is attached to. It is then projected to a single unit with a
  Dense layer. The result keeps a trailing axis of size 1, because no
  squeeze is applied.

  Returns:
    A ``tl.Serial`` combinator: Select([0], n_in=2) -> Dense(1).
  """
  # Drop the second of the two incoming stack elements.
  keep_first = tl.Select([0], n_in=2)
  # Small random-normal kernel init, near-zero random-normal bias init.
  projection = tl.Dense(
      1,
      kernel_initializer=tl.RandomNormalInitializer(0.02),
      bias_initializer=tl.RandomNormalInitializer(1e-6),
  )
  return tl.Serial([keep_first, projection])
Example #2
0
def BERTClassifierHead(n_classes):
  """Build a classification head on top of a two-output BERT body.

  The head keeps the first of the two incoming stack elements. That element
  is presumably the pooled encoding; confirm against the BERT body. It
  projects that element to ``n_classes`` units and applies LogSoftmax, so the
  output is log-probabilities over the classes.

  Args:
    n_classes: Number of output units (classes) of the Dense projection.

  Returns:
    A ``tl.Serial`` combinator: Select([0], n_in=2) -> Dense -> LogSoftmax.
  """
  keep_first = tl.Select([0], n_in=2)
  class_logits = tl.Dense(
      n_classes,
      kernel_initializer=tl.RandomNormalInitializer(0.02),
      bias_initializer=tl.RandomNormalInitializer(1e-6),
  )
  normalize = tl.LogSoftmax()
  return tl.Serial([keep_first, class_logits, normalize])
Example #3
0
def BERTMLMHead(vocab_size=30522):
    """Build a masked-language-model head on top of a two-output BERT body.

    The head keeps the second of the two incoming stack elements. That
    element is presumably the per-token sequence encoding; confirm against
    the BERT body. It projects each element to ``vocab_size`` units. No
    softmax is applied, so the outputs are unnormalized scores.

    Args:
        vocab_size: Width of the output projection. The default of 30522 is
            presumably the BERT WordPiece vocabulary size; verify against
            the tokenizer in use.

    Returns:
        A ``tl.Serial`` combinator: Select([1], n_in=2) -> Dense(vocab_size).
    """
    keep_second = tl.Select([1], n_in=2)
    vocab_projection = tl.Dense(
        vocab_size,
        kernel_initializer=tl.RandomNormalInitializer(0.02),
        bias_initializer=tl.RandomNormalInitializer(1e-6),
    )
    return tl.Serial([keep_second, vocab_projection])
Example #4
0
def BERTRegressionHead():
    """Build a scalar-per-example regression head on a two-output BERT body.

    The head keeps the first of the two incoming stack elements. That element
    is presumably the pooled encoding; confirm against the BERT body. It
    projects that element to one unit with a Dense layer, then squeezes
    axis 1 away.

    Returns:
        A ``tl.Serial`` combinator:
        Select([0], n_in=2) -> Dense(1) -> Fn('RemoveAxis').
    """
    def _drop_unit_axis(x):
        # Removes the size-1 axis produced by Dense(1). This assumes x is
        # rank 2, (batch, 1); for higher ranks axis=1 would not be the Dense
        # output axis. TODO confirm input rank.
        return np.squeeze(x, axis=1)

    keep_first = tl.Select([0], n_in=2)
    projection = tl.Dense(
        1,
        kernel_initializer=tl.RandomNormalInitializer(0.02),
        bias_initializer=tl.RandomNormalInitializer(1e-6),
    )
    squeeze = tl.Fn('RemoveAxis', _drop_unit_axis)
    return tl.Serial([keep_first, projection, squeeze])
Example #5
0
 def test_random_normal(self):
   """Check that RandomNormalInitializer returns an array of the requested shape."""
   init_fn = tl.RandomNormalInitializer()
   # rng() is presumably a test helper returning a PRNG key; defined outside this view.
   key = rng()
   sample = init_fn(INPUT_SHAPE, key)
   self.assertEqual(sample.shape, INPUT_SHAPE)