Example #1
    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids,
                                                       input_mask, is_decoder):
        # no special position embeddings
        config.axial_pos_embds = False
        config.is_decoder = is_decoder

        if self.lsh_attn_chunk_length is not None:
            # need to set the chunk length equal to the sequence length to be certain that chunking works
            config.lsh_attn_chunk_length = self.seq_length

        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        # set all position encodings to zero so that positions don't matter
        with torch.no_grad():
            embedding = model.embeddings.position_embeddings.embedding
            embedding.weight = torch.nn.Parameter(
                torch.zeros(embedding.weight.shape).to(torch_device))
            embedding.weight.requires_grad = False

        half_seq_len = self.seq_length // 2
        roll = self.chunk_length

        half_input_ids = input_ids[:, :half_seq_len]

        # normal padded
        attn_mask = torch.cat(
            [
                torch.ones_like(half_input_ids),
                torch.zeros_like(half_input_ids)
            ],
            dim=-1,
        )
        input_ids_padded = torch.cat(
            [
                half_input_ids,
                ids_tensor((self.batch_size, half_seq_len), self.vocab_size)
            ],
            dim=-1,
        )

        # shifted padded
        input_ids_roll = torch.cat(
            [
                half_input_ids,
                ids_tensor((self.batch_size, half_seq_len), self.vocab_size)
            ],
            dim=-1,
        )
        input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
        attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)

        output_padded = model(input_ids_padded,
                              attention_mask=attn_mask)[0][:, :half_seq_len]
        output_padded_rolled = model(
            input_ids_roll,
            attention_mask=attn_mask_roll)[0][:, roll:half_seq_len + roll]

        self.parent.assertTrue(
            torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
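With the position embeddings zeroed out and the LSH chunk length set to the full sequence length, the model cannot distinguish absolute positions, so the outputs for the padded batch and its rolled copy should agree on the corresponding non-padded tokens; that equivalence is exactly what the final allclose assertion checks.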
Example #2
    def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.half()
        model.eval()
        output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())
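For context, a minimal standalone version of this fp16 smoke test might look like the sketch below; the config values and the CUDA guard are assumptions of this sketch (half precision is normally only exercised on GPU), not part of the original tester class.

import torch
from transformers import ReformerConfig, ReformerModel

if torch.cuda.is_available():
    # small illustrative config, just enough for a forward-pass smoke test
    config = ReformerConfig(attn_layers=["local", "local"], axial_pos_embds=False)
    model = ReformerModel(config).cuda().half().eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 64), device="cuda")
    with torch.no_grad():
        last_hidden_state = model(input_ids)[0]
    # the half-precision forward pass should not produce NaNs
    assert not torch.isnan(last_hidden_state).any().item()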
Example #3
    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        # 2 * hidden_size because we use reversible resnet layers
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.seq_length, 2 * self.hidden_size)
        )
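The same shape check can be reproduced outside the tester class; the sketch below uses illustrative config values that are assumptions of this example, not the ones defined by the tester.

import torch
from transformers import ReformerConfig, ReformerModel

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

config = ReformerConfig(
    vocab_size=320,
    hidden_size=32,
    num_attention_heads=2,
    attention_head_size=16,
    feed_forward_size=64,
    attn_layers=["local", "local"],
    axial_pos_embds=False,
)
model = ReformerModel(config).to(torch_device).eval()

batch_size, seq_length = 2, 64
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length), device=torch_device)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    result = model(input_ids, attention_mask=attention_mask)

# reversible residual layers concatenate two streams, so the output width is 2 * hidden_size
assert result.last_hidden_state.shape == (batch_size, seq_length, 2 * config.hidden_size)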
Example #4
    def test_local_model_forward(self):
        config = self._get_basic_config_and_input()
        config["attn_layers"] = ["local", "local", "local", "local"]
        torch.manual_seed(0)
        model = ReformerModel(ReformerConfig(**config)).to(torch_device)
        model.eval()
        input_ids, attn_mask = self._get_input_ids_and_mask()
        hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
        output_slice = hidden_states[0, 0, :5]
        expected_output_slice = torch.tensor(
            [-1.6791, 0.7171, 0.1594, 0.4063, 1.2584], dtype=torch.float, device=torch_device,
        )
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
Example #5
    def test_lsh_model_forward(self):
        config = self._get_basic_config_and_input()
        config["attn_layers"] = ["lsh", "lsh", "lsh", "lsh"]
        config["num_buckets"] = [2, 4]
        torch.manual_seed(0)
        model = ReformerModel(ReformerConfig(**config)).to(torch_device)
        model.eval()
        input_ids, attn_mask = self._get_input_ids_and_mask()
        hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
        output_slice = hidden_states[0, 0, :5]
        expected_output_slice = torch.tensor(
            [-0.9896, -0.9396, -1.0831, -0.0597, 0.2456], dtype=torch.float, device=torch_device,
        )
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
Example #6
    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        sequence_output, _ = model(input_ids, attention_mask=input_mask)
        sequence_output, _ = model(input_ids)

        result = {
            "sequence_output": sequence_output,
        }
        # 2 * hidden_size because we use reversible resnet layers
        self.parent.assertListEqual(
            list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
        )
    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
        torch.manual_seed(0)
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]

        config.chunk_size_lm_head = 1
        config.chunk_size_feed_forward = 1

        torch.manual_seed(0)
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()

        hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
        self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
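This last check relies on feed-forward chunking being a mathematical no-op: the feed-forward block is applied position-wise, so processing the sequence in chunks and concatenating the results is equivalent to processing it in one pass. A minimal illustration with a plain PyTorch MLP (not the Reformer implementation itself):

import torch
import torch.nn as nn

torch.manual_seed(0)
feed_forward = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

full_output = feed_forward(hidden_states)
chunked_output = torch.cat(
    [feed_forward(chunk) for chunk in hidden_states.chunk(8, dim=1)], dim=1
)

# chunking over the sequence dimension only changes peak memory usage, not the result
assert torch.allclose(full_output, chunked_output, atol=1e-6)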