def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
    # no special position embeddings
    config.axial_pos_embds = False
    config.is_decoder = is_decoder

    if self.lsh_attn_chunk_length is not None:
        # need to set chunk length equal to sequence length to be certain that chunking works
        config.lsh_attn_chunk_length = self.seq_length

    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()

    # set all position encodings to zero so that positions don't matter
    with torch.no_grad():
        embedding = model.embeddings.position_embeddings.embedding
        embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
        embedding.weight.requires_grad = False

    half_seq_len = self.seq_length // 2
    roll = self.chunk_length

    half_input_ids = input_ids[:, :half_seq_len]

    # normal padded
    attn_mask = torch.cat(
        [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)],
        dim=-1,
    )
    input_ids_padded = torch.cat(
        [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
        dim=-1,
    )

    # shifted padded
    input_ids_roll = torch.cat(
        [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
        dim=-1,
    )
    input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
    attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)

    output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
    output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]

    self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.half()
    model.eval()
    # a half-precision forward pass should not produce NaN values
    output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
    self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()
    result = model(input_ids, attention_mask=input_mask)
    result = model(input_ids)
    # 2 * hidden_size because we use reversible resnet layers
    self.parent.assertEqual(
        result.last_hidden_state.shape, (self.batch_size, self.seq_length, 2 * self.hidden_size)
    )
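# A minimal sketch of how a checker like create_and_check_reformer_model is typically driven from
# an actual test method. This would live in the test class that owns a `model_tester` instance, and
# `prepare_config_and_inputs` is an assumed tester helper (not shown in this section) that returns
# (config, input_ids, input_mask, choice_labels); both names are illustrative assumptions here.
def test_reformer_model(self):
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_reformer_model(*config_and_inputs)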
def test_local_model_forward(self):
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["local", "local", "local", "local"]
    torch.manual_seed(0)
    model = ReformerModel(ReformerConfig(**config)).to(torch_device)
    model.eval()
    input_ids, attn_mask = self._get_input_ids_and_mask()
    hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
    output_slice = hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584],
        dtype=torch.float,
        device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_lsh_model_forward(self):
    config = self._get_basic_config_and_input()
    config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
    config["num_buckets"] = [2, 4]
    torch.manual_seed(0)
    model = ReformerModel(ReformerConfig(**config)).to(torch_device)
    model.eval()
    input_ids, attn_mask = self._get_input_ids_and_mask()
    hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
    output_slice = hidden_states[0, 0, :5]
    expected_output_slice = torch.tensor(
        [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456],
        dtype=torch.float,
        device=torch_device,
    )
    self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()
    sequence_output, _ = model(input_ids, attention_mask=input_mask)
    sequence_output, _ = model(input_ids)
    result = {
        "sequence_output": sequence_output,
    }
    # 2 * hidden_size because we use reversible resnet layers
    self.parent.assertListEqual(
        list(result["sequence_output"].size()),
        [self.batch_size, self.seq_length, 2 * self.hidden_size],
    )
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
    torch.manual_seed(0)
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()
    hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]

    # smallest possible chunk sizes: the chunked forward pass must match the unchunked output
    config.chunk_size_lm_head = 1
    config.chunk_size_feed_forward = 1

    torch.manual_seed(0)
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()
    hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
    self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
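# Standalone sketch (plain PyTorch, not the Reformer implementation itself) of the property the
# chunking test above relies on: applying a position-wise feed-forward to the whole sequence and
# applying it chunk-by-chunk along the sequence dimension give the same result, since the layer
# acts on each position independently. Shapes and layer sizes below are illustrative assumptions.
def _chunked_feed_forward_sketch():
    import torch

    torch.manual_seed(0)
    feed_forward = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
    hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

    full_output = feed_forward(hidden_states)
    # process the sequence in 4 chunks of length 2 and stitch the results back together
    chunked_output = torch.cat([feed_forward(chunk) for chunk in hidden_states.chunk(4, dim=1)], dim=1)

    assert torch.allclose(full_output, chunked_output, atol=1e-6)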