def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
    """Check that a training-mode ReformerLayer keeps changing its seeds.

    The layer is run forward 100 times, feeding each call's outputs into the
    next call, while ``layer.attention_seed`` is recorded after every call.
    More than 70 of the 100 recorded values must be distinct. The same
    procedure is then repeated for ``layer.feed_forward_seed``, continuing
    from the tensors left over by the first round.

    ``input_ids`` and ``choice_labels`` are accepted but not used here.
    """
    reformer_layer = ReformerLayer(config).to(torch_device)
    reformer_layer.train()

    # Batch x SeqLen x hiddenSize
    tensor_shape = (self.batch_size, self.seq_length, config.hidden_size)

    # Creation order is kept (hidden states first, then attention output).
    stream_hidden = floats_tensor(tensor_shape)
    stream_attn = floats_tensor(tensor_shape)

    for seed_attr in ("attention_seed", "feed_forward_seed"):
        observed_seeds = set()
        for _ in range(100):
            outputs = reformer_layer(stream_attn, stream_hidden, attention_mask=input_mask)
            stream_attn, stream_hidden = outputs.attn_output, outputs.hidden_states

            # Re-seed torch's global RNG with the seed the layer currently
            # exposes, then record that seed.
            current_seed = getattr(reformer_layer, seed_attr)
            torch.manual_seed(current_seed)
            observed_seeds.add(current_seed)

        self.parent.assertGreater(len(observed_seeds), 70)
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
    """Check that a ReformerLayer forward pass can be replayed from its seeds.

    After one full forward pass in training mode, the attention and
    feed-forward sub-modules are re-run by hand, each right after re-seeding
    torch with the matching seed stored on the layer. The hand-computed
    residual sums must agree with the layer's own outputs within ``atol=1e-3``.

    ``config.is_decoder`` is overwritten with ``is_decoder`` before the layer
    is built. ``input_ids`` is accepted but not used here.
    """
    config.is_decoder = is_decoder
    reformer_layer = ReformerLayer(config).to(torch_device)
    reformer_layer.train()

    # Batch x SeqLen x hiddenSize
    tensor_shape = (self.batch_size, self.seq_length, config.hidden_size)

    # Random inputs; creation order is kept (hidden states first).
    hidden_in = floats_tensor(tensor_shape)
    attn_in = floats_tensor(tensor_shape)

    # Full forward pass with dropout active. The seed attributes read
    # further down are taken from the layer after this call.
    outputs = reformer_layer(attn_in, hidden_in, attention_mask=input_mask)
    attn_out = outputs.attn_output
    hidden_out = outputs.hidden_states

    # Replay the attention sub-module under the recorded attention seed.
    torch.manual_seed(reformer_layer.attention_seed)
    replayed_attn = reformer_layer.attention(hidden_in, attention_mask=input_mask).hidden_states
    attention_matches = torch.allclose(attn_in + replayed_attn, attn_out, atol=1e-3)
    self.parent.assertTrue(attention_matches)

    # Replay the feed-forward sub-module under the recorded feed-forward seed.
    torch.manual_seed(reformer_layer.feed_forward_seed)
    replayed_feed_forward = reformer_layer.feed_forward(attn_out)
    feed_forward_matches = torch.allclose(hidden_out, hidden_in + replayed_feed_forward, atol=1e-3)
    self.parent.assertTrue(feed_forward_matches)