def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
    """Smoke-test one training step (forward + backward) of ``ReformerForMaskedLM``.

    Nothing is asserted: the check passes as long as the forward pass and
    ``backward()`` complete without raising. The whole check is skipped when
    ``self.is_training`` is falsy.

    Args:
        config: Model configuration. It is mutated in place (``is_decoder`` and
            ``lsh_num_chunks_after`` are overwritten) before the model is built.
        input_ids: Model input; also passed as ``labels``.
        input_mask: Passed to the model as ``attention_mask``.
        choice_labels: Unused by this check.
    """
    if self.is_training:
        # Settings applied before the masked-LM model is instantiated.
        config.lsh_num_chunks_after = 1
        config.is_decoder = False

        masked_lm = ReformerForMaskedLM(config=config)
        masked_lm.to(torch_device)
        masked_lm.train()

        outputs = masked_lm(input_ids, attention_mask=input_mask, labels=input_ids)
        outputs["loss"].backward()
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
    """Check that enabling chunking leaves the loss and embedding gradients unchanged.

    Two ``ReformerForMaskedLM`` models are built from the same RNG seed: one
    with the incoming ``config`` and one after setting ``chunk_size_lm_head``
    and ``chunk_size_feed_forward`` to 1. Each runs a single forward/backward
    pass on the same inputs, and the resulting loss plus small slices of the
    word-embedding and position-embedding gradients are compared with
    ``atol=1e-3``. The whole check is skipped when ``self.is_training`` is
    falsy.

    Args:
        config: Model configuration. It is mutated in place (dropout
            probabilities, ``lsh_num_chunks_after``, ``is_decoder`` and the
            two chunk sizes are overwritten).
        input_ids: Model input; also passed as ``labels``.
        input_mask: Passed to the model as ``attention_mask``.
        choice_labels: Unused by this check.
    """
    if not self.is_training:
        return

    # disable dropout
    config.hidden_dropout_prob = 0
    config.local_attention_probs_dropout_prob = 0
    config.lsh_attention_probs_dropout_prob = 0

    config.lsh_num_chunks_after = 1
    config.is_decoder = False

    def loss_and_grad_slices():
        """Build a model from the current ``config``, run one backward pass, return loss and grad slices."""
        # Re-seed before construction so both models are built from the same RNG state.
        torch.manual_seed(0)
        model = ReformerForMaskedLM(config=config)
        model.to(torch_device)
        model.train()
        model.zero_grad()

        loss = model(input_ids, labels=input_ids, attention_mask=input_mask)[0]
        loss.backward()

        embeddings = model.reformer.embeddings
        position_weights = embeddings.position_embeddings.weights
        # Read ``.grad`` of the position-embedding factors, not the factors
        # themselves: comparing raw parameter values would not exercise the
        # backward pass at all.
        # NOTE(review): assumes each entry of ``weights`` is a parameter that
        # takes part in the forward pass, so ``.grad`` is populated - confirm.
        return (
            loss,
            embeddings.word_embeddings.weight.grad[0, :5],
            position_weights[0].grad[1, 0, -5:],
            position_weights[1].grad[0, 1, :5],
        )

    (
        loss_no_chunk,
        grad_slice_word_no_chunk,
        grad_slice_position_factor_1_no_chunk,
        grad_slice_position_factor_2_no_chunk,
    ) = loss_and_grad_slices()

    # Second run: identical setup except that chunking is switched on.
    config.chunk_size_lm_head = 1
    config.chunk_size_feed_forward = 1

    (
        loss_chunk,
        grad_slice_word_chunk,
        grad_slice_position_factor_1_chunk,
        grad_slice_position_factor_2_chunk,
    ) = loss_and_grad_slices()

    self.parent.assertTrue(torch.allclose(loss_chunk, loss_no_chunk, atol=1e-3))
    self.parent.assertTrue(torch.allclose(grad_slice_word_no_chunk, grad_slice_word_chunk, atol=1e-3))
    self.parent.assertTrue(
        torch.allclose(grad_slice_position_factor_1_chunk, grad_slice_position_factor_1_no_chunk, atol=1e-3)
    )
    self.parent.assertTrue(
        torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
    )