def test_trainable_variables(self):
    r"""Tests the functionality of automatically collecting trainable
    variables.
    """
    # Case 1: XLNet base. The expected counts below are consistent with
    # 15 trainable tensors per layer plus 2 model-level tensors,
    # i.e. 2 + num_layers * 15.
    encoder = XLNetEncoder()
    self.assertEqual(len(encoder.trainable_variables), 182)
    _, _ = encoder(self.inputs)

    # Case 2: XLNet large.
    hparams = {
        "pretrained_model_name": "xlnet-large-cased",
    }
    encoder = XLNetEncoder(hparams=hparams)
    self.assertEqual(len(encoder.trainable_variables), 362)
    _, _ = encoder(self.inputs)

    # Case 3: self-designed XLNet.
    hparams = {
        "num_layers": 6,
        "pretrained_model_name": None,
    }
    encoder = XLNetEncoder(hparams=hparams)
    self.assertEqual(len(encoder.trainable_variables), 92)
    _, _ = encoder(self.inputs)

def test_hparams(self):
    r"""Tests the priority of the encoder architecture parameters.
    """
    # Case 1: set "pretrained_model_name" by constructor argument;
    # the constructor argument takes priority over hparams.
    hparams = {
        "pretrained_model_name": "xlnet-large-cased",
    }
    encoder = XLNetEncoder(pretrained_model_name="xlnet-base-cased",
                           hparams=hparams)
    self.assertEqual(encoder.hparams.num_layers, 12)
    _, _ = encoder(self.inputs)

    # Case 2: set "pretrained_model_name" by hparams; the pre-trained
    # configuration overrides manually specified architecture hparams.
    hparams = {
        "pretrained_model_name": "xlnet-large-cased",
        "num_layers": 6,
    }
    encoder = XLNetEncoder(hparams=hparams)
    self.assertEqual(encoder.hparams.num_layers, 24)
    _, _ = encoder(self.inputs)

    # Case 3: set "pretrained_model_name" to None in both hparams and
    # the constructor argument; the manual architecture takes effect.
    hparams = {
        "pretrained_model_name": None,
        "num_layers": 6,
    }
    encoder = XLNetEncoder(hparams=hparams)
    self.assertEqual(encoder.hparams.num_layers, 6)
    _, _ = encoder(self.inputs)

    # Case 4: using default hparams (XLNet base).
    encoder = XLNetEncoder()
    self.assertEqual(encoder.hparams.num_layers, 12)
    _, _ = encoder(self.inputs)

def test_encode(self):
    r"""Tests encoding.
    """
    # Case 1: XLNet base architecture (no pre-trained weights loaded).
    hparams = {
        "pretrained_model_name": None,
    }
    encoder = XLNetEncoder(hparams=hparams)

    inputs = torch.randint(32000, (self.batch_size, self.max_length))
    outputs, new_memory = encoder(inputs)
    self.assertEqual(
        outputs.shape,
        torch.Size([self.batch_size, self.max_length,
                    encoder.output_size]))
    self.assertIsNone(new_memory)

    # Case 2: self-designed XLNet.
    hparams = {
        'pretrained_model_name': None,
        'untie_r': True,
        'num_layers': 6,
        'mem_len': 0,
        'reuse_len': 0,
        'num_heads': 8,
        'hidden_dim': 32,
        'head_dim': 64,
        'dropout': 0.1,
        'attention_dropout': 0.1,
        'use_segments': True,
        'ffn_inner_dim': 256,
        'activation': 'gelu',
        'vocab_size': 32000,
        'max_seq_length': 128,
        'initializer': None,
        'name': "xlnet_encoder",
    }
    encoder = XLNetEncoder(hparams=hparams)
    outputs, new_memory = encoder(inputs)
    self.assertEqual(
        outputs.shape,
        torch.Size([self.batch_size, self.max_length,
                    encoder.output_size]))
    self.assertIsNone(new_memory)

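# Both cases above run without recurrence memory (``mem_len`` is 0 or
# left at its default), which is why ``new_memory`` is None. A minimal
# sketch of how segment-level recurrence might be exercised instead;
# the assumption here (not verified by this suite) is that a positive
# ``mem_len`` makes the encoder return per-layer memory states that can
# be fed back in through a ``memory`` argument on the next segment:
#
#     hparams = {"pretrained_model_name": None, "mem_len": 64}
#     encoder = XLNetEncoder(hparams=hparams)
#     outputs, new_memory = encoder(first_segment_ids)
#     outputs, new_memory = encoder(next_segment_ids, memory=new_memory)
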
def __init__(self,
             pretrained_model_name: Optional[str] = None,
             cache_dir: Optional[str] = None,
             hparams=None):
    super().__init__(hparams=hparams)

    # Create the underlying encoder.
    encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams())
    self._encoder = XLNetEncoder(
        pretrained_model_name=pretrained_model_name,
        cache_dir=cache_dir,
        hparams=encoder_hparams)

    # TODO: The logic here is very similar to that in XLNetClassifier.
    #   We need to reduce the code redundancy.
    if self._hparams.use_projection:
        if self._hparams.regr_strategy == 'all_time':
            self.projection = nn.Linear(
                self._encoder.output_size * self._hparams.max_seq_length,
                self._encoder.output_size * self._hparams.max_seq_length)
        else:
            self.projection = nn.Linear(self._encoder.output_size,
                                        self._encoder.output_size)
    self.dropout = nn.Dropout(self._hparams.dropout)

    logit_kwargs = self._hparams.logit_layer_kwargs
    if logit_kwargs is None:
        logit_kwargs = {}
    elif not isinstance(logit_kwargs, HParams):
        raise ValueError("hparams['logit_layer_kwargs'] "
                         "must be a dict.")
    else:
        logit_kwargs = logit_kwargs.todict()

    if self._hparams.regr_strategy == 'all_time':
        self.hidden_to_logits = nn.Linear(
            self._encoder.output_size * self._hparams.max_seq_length,
            1, **logit_kwargs)
    else:
        self.hidden_to_logits = nn.Linear(
            self._encoder.output_size, 1, **logit_kwargs)

    if self._hparams.initializer:
        initialize = get_initializer(self._hparams.initializer)
        assert initialize is not None
        if self._hparams.use_projection:
            initialize(self.projection.weight)
            initialize(self.projection.bias)
        initialize(self.hidden_to_logits.weight)
        # Check against None explicitly: the truth value of a
        # multi-element Parameter is ambiguous and raises at runtime.
        if self.hidden_to_logits.bias is not None:
            initialize(self.hidden_to_logits.bias)
    else:
        if self._hparams.use_projection:
            self.projection.apply(init_weights)
        self.hidden_to_logits.apply(init_weights)

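# A minimal usage sketch for the module constructed above. The class
# name ``XLNetRegressor``, the ``regr_strategy`` value, and the forward
# signature are assumptions for illustration, not taken from this
# section:
#
#     regressor = XLNetRegressor(hparams={"regr_strategy": "cls_time"})
#     token_ids = torch.randint(32000, (batch_size, max_length))
#     scores = regressor(token_ids)  # one scalar score per example
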
def test_soft_ids(self):
    r"""Tests soft ids.
    """
    hparams = {
        "pretrained_model_name": None,
    }
    encoder = XLNetEncoder(hparams=hparams)

    inputs = torch.rand(self.batch_size, self.max_length, 32000)
    outputs, new_memory = encoder(inputs)
    self.assertEqual(
        outputs.shape,
        torch.Size([self.batch_size, self.max_length,
                    encoder.output_size]))
    self.assertIsNone(new_memory)

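# Soft ids replace hard token indices with distributions over the
# vocabulary, turning the embedding lookup into a matrix product. A
# minimal sketch of that idea; ``_soft_lookup_sketch`` is a hypothetical
# helper, not the encoder's actual implementation:
def _soft_lookup_sketch(self, soft_ids, embedding_weight):
    # soft_ids: [batch, length, vocab]; embedding_weight: [vocab, dim].
    # If each row of ``soft_ids`` is one-hot, this reduces to an
    # ordinary (hard) embedding lookup.
    return torch.matmul(soft_ids, embedding_weight)
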
def test_model_loading(self):
    r"""Tests model loading functionality."""
    for pretrained_model_name in XLNetEncoder.available_checkpoints():
        encoder = XLNetEncoder(
            pretrained_model_name=pretrained_model_name)
        _ = encoder(self.inputs)