def test_local_layer_forward_complex(self):
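    # Local self-attention with an explicit attention mask. ReformerLayer is
    # reversible, so its forward pass takes both residual streams:
    # prev_attn_output (attention stream) and hidden_states (feed-forward stream).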
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["local"]
    attn_mask = self._get_attn_mask()
    hidden_states = self._get_hidden_states()
    torch.manual_seed(0)
    layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
    layer.eval()
    reformer_output = layer(
        prev_attn_output=hidden_states, hidden_states=hidden_states, attention_mask=attn_mask,
    )
    output_slice = reformer_output.hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [1.5476, -1.9020, -0.9902, 1.5013, -0.1950], dtype=torch.float, device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

def test_local_layer_forward(self):
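    # Local self-attention without a mask, run in encoder mode (is_decoder=False).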
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["local"]
    config["is_decoder"] = False
    hidden_states = self._get_hidden_states()
    torch.manual_seed(0)
    layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
    layer.eval()
    reformer_output = layer(prev_attn_output=hidden_states, hidden_states=hidden_states)
    output_slice = reformer_output.hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [1.4212, -2.0576, -0.9688, 1.4599, -0.1344], dtype=torch.float, device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

def test_lsh_layer_forward(self):
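    # LSH self-attention without a mask, run in encoder mode (is_decoder=False).
    # prev_attn_output is passed as a clone so the two inputs do not share storage.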
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["lsh"]
    config["is_decoder"] = False
    hidden_states = self._get_hidden_states()
    torch.manual_seed(0)
    layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
    layer.eval()
    reformer_output = layer(prev_attn_output=hidden_states.clone(), hidden_states=hidden_states)
    output_slice = reformer_output.hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [1.6879, -1.3083, -0.4708, 1.3555, -0.6292], dtype=torch.float, device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

def test_lsh_layer_forward_complex(self):
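    # LSH self-attention with an attention mask and custom bucket counts
    # (num_buckets=[2, 4] factorizes the hash buckets across two dimensions).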
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["lsh"]
    config["num_buckets"] = [2, 4]
    attn_mask = self._get_attn_mask()
    hidden_states = self._get_hidden_states()
    torch.manual_seed(0)
    layer = ReformerLayer(ReformerConfig(**config)).to(torch_device)
    layer.eval()
    reformer_output = layer(
        prev_attn_output=hidden_states.clone(), hidden_states=hidden_states, attention_mask=attn_mask,
    )
    output_slice = reformer_output.hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [1.6439, -1.2306, -0.5108, 1.3006, -0.6537], dtype=torch.float, device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))