Example #1
    def create_and_check_reformer_model(self, config, input_ids, input_mask,
                                        choice_labels):
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        # forward pass with and without an attention mask; only the output shape is checked here
        (sequence_output,) = model(input_ids, attention_mask=input_mask)
        (sequence_output,) = model(input_ids)

        result = {
            "sequence_output": sequence_output,
        }
        # 2 * hidden_size because we use reversible resnet layers
        self.parent.assertListEqual(
            list(result["sequence_output"].size()),
            [self.batch_size, self.seq_length, 2 * self.hidden_size],
        )
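The shape check above follows from Reformer's reversible residual (RevNet-style) layers: the encoder keeps two activation streams and concatenates them at the output, so the last dimension is 2 * hidden_size rather than hidden_size. The sketch below is plain PyTorch, not the actual ReformerModel internals (ToyReversibleBlock is a made-up name), and only illustrates where the doubling comes from.

import torch
import torch.nn as nn

class ToyReversibleBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.f = nn.Linear(hidden_size, hidden_size)  # stand-in for the attention sub-layer
        self.g = nn.Linear(hidden_size, hidden_size)  # stand-in for the feed forward sub-layer

    def forward(self, x1, x2):
        # reversible residual update: y1 = x1 + F(x2), y2 = x2 + G(y1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

hidden_size = 32
hidden_states = torch.randn(2, 16, hidden_size)   # (batch, seq_len, hidden_size)
block = ToyReversibleBlock(hidden_size)
y1, y2 = block(hidden_states, hidden_states)      # both streams start from the embeddings
output = torch.cat([y1, y2], dim=-1)              # the two streams are concatenated at the end
assert output.shape == (2, 16, 2 * hidden_size)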
Example #2
    def create_and_check_reformer_model_with_attn_mask(
        self, config, input_ids, input_mask, choice_labels, is_decoder=False
    ):
        # no special position embeddings
        config.axial_pos_embds = False
        config.is_decoder = is_decoder

        if self.lsh_attn_chunk_length is not None:
            # need to set chunk length equal to the sequence length to be certain that chunking works
            config.lsh_attn_chunk_length = self.seq_length

        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        # set all position encodings to zero so that positions don't matter
        with torch.no_grad():
            embedding = model.embeddings.position_embeddings.embedding
            embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
            embedding.weight.requires_grad = False

        half_seq_len = self.seq_length // 2
        roll = self.chunk_length

        half_input_ids = input_ids[:, :half_seq_len]

        # normal padded: real tokens in the first half, masked-out filler ids in the second half
        attn_mask = torch.cat(
            [torch.ones_like(half_input_ids), torch.zeros_like(half_input_ids)],
            dim=-1,
        )
        input_ids_padded = torch.cat(
            [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
            dim=-1,
        )

        # shifted padded: same construction, rolled by `roll` positions along the sequence dimension
        input_ids_roll = torch.cat(
            [half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
            dim=-1,
        )
        input_ids_roll = torch.roll(input_ids_roll, roll, dims=-1)
        attn_mask_roll = torch.roll(attn_mask, roll, dims=-1)

        # outputs over the real (first-half) tokens must match once the rolled output is re-aligned
        output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
        output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]

        self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
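The final assertion compares the model output on the real (first-half) tokens against the output on the rolled inputs, re-aligned with the same offset. The toy snippet below (plain tensors, no model, arbitrarily chosen sizes) shows the indexing the test relies on.

import torch

seq_len, half, roll = 8, 4, 2
ids = torch.arange(seq_len).unsqueeze(0)     # [[0, 1, 2, 3, 4, 5, 6, 7]]
rolled = torch.roll(ids, roll, dims=-1)      # [[6, 7, 0, 1, 2, 3, 4, 5]]

# the first `half` positions of the original line up with positions roll : half + roll after rolling
assert torch.equal(ids[:, :half], rolled[:, roll : half + roll])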

    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
        # seed so that both models below are initialized with identical weights
        torch.manual_seed(0)
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        # reference forward pass without feed forward chunking
        hidden_states_no_chunk = model(input_ids, attention_mask=input_mask)[0]

        # enable feed forward chunking with the smallest possible chunk size
        config.chunk_size_lm_head = 1
        config.chunk_size_feed_forward = 1

        # re-seed so the second model gets exactly the same weights as the first
        torch.manual_seed(0)
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()

        # forward pass with chunking enabled must reproduce the unchunked hidden states
        hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
        self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
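The chunking test only works because the feed forward layers are position-wise: splitting the sequence into chunks, running each chunk through the layer, and concatenating the results must reproduce the unchunked output. A minimal standalone sketch of that equivalence (plain PyTorch, not the Reformer implementation):

import torch
import torch.nn as nn

torch.manual_seed(0)
ffn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
hidden_states = torch.randn(2, 8, 16)        # (batch, seq_len, hidden_size)

# unchunked: apply the feed forward layer to the full sequence at once
full = ffn(hidden_states)

# chunked: apply it to one sequence chunk at a time and concatenate the results
chunk_size = 1
chunked = torch.cat([ffn(chunk) for chunk in hidden_states.split(chunk_size, dim=1)], dim=1)

assert torch.allclose(full, chunked, atol=1e-6)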