def BERTRegressionHead():
    """Build a serial head: pick input 0 of 2, then a 1-unit Dense layer.

    The Dense kernel is initialized with ``RandomNormalInitializer(0.02)``
    and the bias with ``RandomNormalInitializer(1e-6)``.

    NOTE(review): a later definition with the same name in this file appends
    a 'RemoveAxis' step and would shadow this one at import time -- confirm
    which variant is intended.

    Returns:
        A ``tl.Serial`` combinator wrapping the two layers.
    """
    # Presumably the first of the two encoder outputs is the pooled
    # embedding -- TODO confirm against the encoder that feeds this head.
    keep_first = tl.Select([0], n_in=2)
    kernel_init = tl.RandomNormalInitializer(0.02)
    bias_init = tl.RandomNormalInitializer(1e-6)
    projection = tl.Dense(
        1,
        kernel_initializer=kernel_init,
        bias_initializer=bias_init,
    )
    return tl.Serial([keep_first, projection])
def BERTClassifierHead(n_classes):
    """Build a serial head: pick input 0 of 2, Dense, then LogSoftmax.

    Args:
        n_classes: Number of units of the Dense layer (one per class).

    Returns:
        A ``tl.Serial`` combinator whose last layer is ``tl.LogSoftmax``.
    """
    # Presumably the first of the two encoder outputs is the pooled
    # embedding -- TODO confirm against the encoder that feeds this head.
    keep_first = tl.Select([0], n_in=2)
    kernel_init = tl.RandomNormalInitializer(0.02)
    bias_init = tl.RandomNormalInitializer(1e-6)
    logits = tl.Dense(
        n_classes,
        kernel_initializer=kernel_init,
        bias_initializer=bias_init,
    )
    log_probs = tl.LogSoftmax()
    return tl.Serial([keep_first, logits, log_probs])
def BERTMLMHead(vocab_size=30522):
    """Build a serial head: pick input 1 of 2, then a vocab-sized Dense layer.

    Args:
        vocab_size: Number of units of the Dense layer. The default of 30522
            presumably matches the vocabulary of the pretrained checkpoint
            in use -- TODO confirm.

    Returns:
        A ``tl.Serial`` combinator wrapping the two layers. No softmax is
        applied here; the Dense output is returned as is.
    """
    # Unlike the other heads, this one keeps the second of the two inputs
    # (presumably the per-token sequence output -- TODO confirm).
    keep_second = tl.Select([1], n_in=2)
    kernel_init = tl.RandomNormalInitializer(0.02)
    bias_init = tl.RandomNormalInitializer(1e-6)
    token_logits = tl.Dense(
        vocab_size,
        kernel_initializer=kernel_init,
        bias_initializer=bias_init,
    )
    return tl.Serial([keep_second, token_logits])
def BERTRegressionHead():
    """Build a serial head: pick input 0 of 2, 1-unit Dense, squeeze axis 1.

    The Dense kernel is initialized with ``RandomNormalInitializer(0.02)``
    and the bias with ``RandomNormalInitializer(1e-6)``. The last layer,
    named 'RemoveAxis', applies ``np.squeeze(x, axis=1)`` to the Dense
    output.

    NOTE(review): an earlier definition with the same name exists in this
    file without the 'RemoveAxis' step; this later one takes precedence at
    import time -- confirm the earlier one can be dropped.

    Returns:
        A ``tl.Serial`` combinator wrapping the three layers.
    """

    def _drop_axis_one(x):
        # Presumably turns a (batch, 1) prediction into (batch,) -- the
        # input rank is not visible from here, TODO confirm.
        return np.squeeze(x, axis=1)

    keep_first = tl.Select([0], n_in=2)
    kernel_init = tl.RandomNormalInitializer(0.02)
    bias_init = tl.RandomNormalInitializer(1e-6)
    projection = tl.Dense(
        1,
        kernel_initializer=kernel_init,
        bias_initializer=bias_init,
    )
    remove_axis = tl.Fn('RemoveAxis', _drop_axis_one)
    return tl.Serial([keep_first, projection, remove_axis])
def test_random_normal(self):
    """Check that a default RandomNormalInitializer keeps the requested shape.

    Calls the initializer with ``INPUT_SHAPE`` and a value from ``rng()``,
    then asserts that the produced value's ``shape`` equals ``INPUT_SHAPE``.
    """
    initializer = tl.RandomNormalInitializer()
    weights = initializer(INPUT_SHAPE, rng())
    self.assertEqual(weights.shape, INPUT_SHAPE)