def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
    """Check that ReformerModel runs with and without an attention mask.

    Both forward passes (with ``input_mask`` and without any mask) must return
    a single sequence output of shape
    ``[self.batch_size, self.seq_length, 2 * self.hidden_size]``.

    ``choice_labels`` is unused; it keeps the signature in line with the other
    ``create_and_check_*`` helpers that take the same positional arguments.
    """
    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()

    # 2 * hidden_size because we use reversible resnet layers
    expected_shape = [self.batch_size, self.seq_length, 2 * self.hidden_size]

    # Unpacking into a 1-tuple fails unless the model returns exactly one output.
    (sequence_output_with_mask,) = model(input_ids, attention_mask=input_mask)
    (sequence_output,) = model(input_ids)

    # Check the masked and the unmasked forward pass, not just the last one.
    self.parent.assertListEqual(list(sequence_output_with_mask.size()), expected_shape)
    self.parent.assertListEqual(list(sequence_output.size()), expected_shape)
def create_and_check_reformer_model_with_attn_mask(
    self, config, input_ids, input_mask, choice_labels, is_decoder=False
):
    """Check that masked padding tokens do not change the outputs of kept tokens.

    The first ``self.seq_length // 2`` tokens of ``input_ids`` are kept and the
    second half is filled with tokens from ``ids_tensor`` and masked out. A
    second input is built the same way (with its own ``ids_tensor`` padding)
    and then rolled right by ``self.chunk_length`` together with the mask.
    With all position-embedding weights zeroed, the outputs at the kept
    positions must match (``atol=1e-3``) between the two layouts.

    ``input_mask`` and ``choice_labels`` are unused. ``config`` is modified in
    place (``axial_pos_embds``, ``is_decoder`` and possibly
    ``lsh_attn_chunk_length``).
    """
    # no special (axial) position embeddings
    config.axial_pos_embds = False
    config.is_decoder = is_decoder
    if self.lsh_attn_chunk_length is not None:
        # need to set chunk length equal sequence length to be certain that chunking works
        config.lsh_attn_chunk_length = self.seq_length

    model = ReformerModel(config=config)
    model.to(torch_device)
    model.eval()

    # Replace the position-embedding weight with zeros so absolute positions don't matter.
    with torch.no_grad():
        pos_embedding = model.embeddings.position_embeddings.embedding
        zero_weight = torch.zeros(pos_embedding.weight.shape).to(torch_device)
        pos_embedding.weight = torch.nn.Parameter(zero_weight)
        pos_embedding.weight.requires_grad = False

    half = self.seq_length // 2
    shift = self.chunk_length
    kept_ids = input_ids[:, :half]

    # 1 over the kept half, 0 over the padding half
    mask = torch.cat([torch.ones_like(kept_ids), torch.zeros_like(kept_ids)], dim=-1)

    def with_padding():
        # kept tokens followed by a separately generated padding half
        padding = ids_tensor((self.batch_size, half), self.vocab_size)
        return torch.cat([kept_ids, padding], dim=-1)

    # Call order preserved: the unshifted input's padding is generated first.
    ids_plain = with_padding()
    ids_shifted = torch.roll(with_padding(), shift, dims=-1)
    mask_shifted = torch.roll(mask, shift, dims=-1)

    hidden_plain = model(ids_plain, attention_mask=mask)[0]
    hidden_shifted = model(ids_shifted, attention_mask=mask_shifted)[0]

    self.parent.assertTrue(
        torch.allclose(
            hidden_plain[:, :half],
            hidden_shifted[:, shift : half + shift],
            atol=1e-3,
        )
    )
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels=None):
    """Check that LM-head / feed-forward chunking does not change the hidden states.

    Two ReformerModel instances are built after ``torch.manual_seed(0)``. The
    first uses ``config`` as given; the second uses it with
    ``chunk_size_lm_head`` and ``chunk_size_feed_forward`` set to 1. Their
    first outputs must agree within ``atol=1e-3``.

    Note: ``config`` is modified in place; both chunk sizes stay at 1 afterwards.

    ``choice_labels`` is unused. It is accepted (optionally) so this helper can
    be called with the same ``(config, input_ids, input_mask, choice_labels)``
    positional arguments as the other ``create_and_check_*`` helpers.
    """

    def _first_output():
        # Re-seed before building so both models start from the same torch RNG
        # state for initialisation and for any randomness in the forward pass.
        torch.manual_seed(0)
        model = ReformerModel(config=config)
        model.to(torch_device)
        model.eval()
        return model(input_ids, attention_mask=input_mask)[0]

    hidden_states_no_chunk = _first_output()

    config.chunk_size_lm_head = 1
    config.chunk_size_feed_forward = 1
    hidden_states_with_chunk = _first_output()

    self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))