def _get_decoder(self, units, vocab_size, embed, prefix):
    """Construct the decoder head for the masked language model task.

    The decoder is ``Dense(units) -> GELU -> BERTLayerNorm -> Dense(vocab_size)``
    where the final projection's parameters are tied to (shared with) the
    word embedding.

    Parameters
    ----------
    units : int
        Width of the intermediate dense / layer-norm stack.
    vocab_size : int
        Size of the output vocabulary for the final projection.
    embed : Block
        Word-embedding block whose parameters are reused by the final
        Dense layer (weight tying).
    prefix : str
        Parameter-name prefix for the new sequential block.

    Returns
    -------
    nn.HybridSequential
        The assembled decoder with tied output weights.
    """
    with self.name_scope():
        decoder = nn.HybridSequential(prefix=prefix)
        decoder.add(nn.Dense(units, flatten=False))
        decoder.add(GELU())
        decoder.add(BERTLayerNorm(in_channels=units))
        # Tie the output projection to the word embedding by passing the
        # embedding's parameter dict so the Dense layer reuses it.
        decoder.add(nn.Dense(vocab_size, flatten=False,
                             params=embed.collect_params()))
    # Identity check (`is`), not `==`: weight tying must reuse the very same
    # Parameter object as the embedding, not a merely equal-valued copy.
    assert decoder[3].weight is list(embed.collect_params().values())[0], \
        'The weights of word embedding are not tied with those of decoder'
    return decoder
def __init__(self, units, hidden_size,
             weight_initializer=mx.init.Normal(0.02),
             bias_initializer='zeros', prefix=None, params=None):
    """Position-wise feed-forward block: Dense(hidden_size) -> GELU -> Dense(units).

    Parameters
    ----------
    units : int
        Model (input/output) dimension of the layer.
    hidden_size : int
        Width of the inner expansion projection.
    weight_initializer : Initializer, default mx.init.Normal(0.02)
        Initializer for the weights of both Dense sub-layers.
    bias_initializer : str or Initializer, default 'zeros'
        Initializer for the biases of both Dense sub-layers.
    prefix : str, optional
        Parameter-name prefix, forwarded to the parent block.
    params : ParameterDict, optional
        Shared parameters, forwarded to the parent block.
    """
    super(GPT2FFNLayer, self).__init__(prefix=prefix, params=params)
    self._units = units
    self._hidden_size = hidden_size
    with self.name_scope():
        # Expansion projection into the FFN hidden space.
        self._hidden_map = nn.Dense(hidden_size, flatten=False,
                                    weight_initializer=weight_initializer,
                                    bias_initializer=bias_initializer)
        # Contraction projection back to the model dimension.
        self._out_map = nn.Dense(units, flatten=False,
                                 weight_initializer=weight_initializer,
                                 bias_initializer=bias_initializer)
        # Nonlinearity applied between the two projections.
        self._act = GELU()