def test_local_layer_forward_complex(self):
    """Check a masked forward pass through a local-attention ``ReformerLayer``.

    The basic config is overridden to use a single ``"local"`` attention
    layer, the layer is called with an attention mask, and the first five
    features at batch index 0 / position 0 are compared against stored
    reference values with an absolute tolerance of 1e-3.
    """
    expected_values = [1.5476, -1.9020, -0.9902, 1.5013, -0.1950]

    layer_kwargs = self._get_basic_config_and_input()
    layer_kwargs.update({"attn_layers": ["local"]})
    mask = self._get_attn_mask()
    inputs = self._get_hidden_states()

    # The seed is fixed immediately before the layer is built; the reference
    # values below presumably depend on that ordering, so keep it.
    torch.manual_seed(0)
    reformer_config = ReformerConfig(**layer_kwargs)
    layer = ReformerLayer(reformer_config).to(torch_device)
    layer.eval()

    # The same tensor is deliberately passed for both inputs (no clone here).
    outputs = layer(
        prev_attn_output=inputs,
        hidden_states=inputs,
        attention_mask=mask,
    )

    actual = outputs.hidden_states[0, 0, :5]
    expected = torch.tensor(expected_values, dtype=torch.float, device=torch_device)
    matches = torch.allclose(actual, expected, atol=1e-3)
    self.assertTrue(matches)
def test_local_layer_forward(self):
    """Check an unmasked forward pass through a local-attention ``ReformerLayer``.

    The basic config is overridden to use a single ``"local"`` attention
    layer with ``is_decoder`` set to ``False``. No attention mask is passed.
    The first five features at batch index 0 / position 0 are compared
    against stored reference values with an absolute tolerance of 1e-3.
    """
    expected_values = [1.4212, -2.0576, -0.9688, 1.4599, -0.1344]

    layer_kwargs = self._get_basic_config_and_input()
    layer_kwargs.update({"attn_layers": ["local"], "is_decoder": False})
    inputs = self._get_hidden_states()

    # The seed is fixed immediately before the layer is built; the reference
    # values below presumably depend on that ordering, so keep it.
    torch.manual_seed(0)
    reformer_config = ReformerConfig(**layer_kwargs)
    layer = ReformerLayer(reformer_config).to(torch_device)
    layer.eval()

    # The same tensor is deliberately passed for both inputs (no clone here).
    outputs = layer(prev_attn_output=inputs, hidden_states=inputs)

    actual = outputs.hidden_states[0, 0, :5]
    expected = torch.tensor(expected_values, dtype=torch.float, device=torch_device)
    matches = torch.allclose(actual, expected, atol=1e-3)
    self.assertTrue(matches)
def test_lsh_layer_forward(self):
    """Check an unmasked forward pass through an LSH-attention ``ReformerLayer``.

    The basic config is overridden to use a single ``"lsh"`` attention layer
    with ``is_decoder`` set to ``False``. No attention mask is passed. The
    first five features at batch index 0 / position 0 are compared against
    stored reference values with an absolute tolerance of 1e-3.
    """
    expected_values = [1.6879, -1.3083, -0.4708, 1.3555, -0.6292]

    layer_kwargs = self._get_basic_config_and_input()
    layer_kwargs.update({"attn_layers": ["lsh"], "is_decoder": False})
    inputs = self._get_hidden_states()

    # The seed is fixed immediately before the layer is built; the reference
    # values below presumably depend on that ordering, so keep it.
    torch.manual_seed(0)
    reformer_config = ReformerConfig(**layer_kwargs)
    layer = ReformerLayer(reformer_config).to(torch_device)
    layer.eval()

    # ``prev_attn_output`` receives its own copy of the input tensor.
    previous_attention = inputs.clone()
    outputs = layer(prev_attn_output=previous_attention, hidden_states=inputs)

    actual = outputs.hidden_states[0, 0, :5]
    expected = torch.tensor(expected_values, dtype=torch.float, device=torch_device)
    matches = torch.allclose(actual, expected, atol=1e-3)
    self.assertTrue(matches)
def test_lsh_layer_forward_complex(self):
    """Check a masked forward pass through an LSH-attention ``ReformerLayer``.

    The basic config is overridden to use a single ``"lsh"`` attention layer
    with ``num_buckets`` set to ``[2, 4]``, and the layer is called with an
    attention mask. The first five features at batch index 0 / position 0
    are compared against stored reference values with an absolute tolerance
    of 1e-3.
    """
    expected_values = [1.6439, -1.2306, -0.5108, 1.3006, -0.6537]

    layer_kwargs = self._get_basic_config_and_input()
    layer_kwargs.update({"attn_layers": ["lsh"], "num_buckets": [2, 4]})
    mask = self._get_attn_mask()
    inputs = self._get_hidden_states()

    # The seed is fixed immediately before the layer is built; the reference
    # values below presumably depend on that ordering, so keep it.
    torch.manual_seed(0)
    reformer_config = ReformerConfig(**layer_kwargs)
    layer = ReformerLayer(reformer_config).to(torch_device)
    layer.eval()

    # ``prev_attn_output`` receives its own copy of the input tensor.
    previous_attention = inputs.clone()
    outputs = layer(
        prev_attn_output=previous_attention,
        hidden_states=inputs,
        attention_mask=mask,
    )

    actual = outputs.hidden_states[0, 0, :5]
    expected = torch.tensor(expected_values, dtype=torch.float, device=torch_device)
    matches = torch.allclose(actual, expected, atol=1e-3)
    self.assertTrue(matches)