def test_local_lm_model_grad(self):
    """Regression-check the LM loss and a few parameter slices after backward.

    Builds a ``ReformerModelWithLMHead`` whose four attention layers are all
    ``"local"``, with both dropout probabilities set to 0.0, runs one
    forward/backward pass in train mode using the input ids as labels, and
    compares the results against hard-coded reference values (atol=1e-3).
    """

    def reference(values):
        # All reference values are float tensors living on the test device.
        return torch.tensor(values, dtype=torch.float, device=torch_device)

    config = self._get_basic_config_and_input()
    overrides = (
        ("attn_layers", ["local", "local", "local", "local"]),
        ("hidden_dropout_prob", 0.0),
        ("local_attention_probs_dropout_prob", 0.0),
    )
    for option, value in overrides:
        config[option] = value

    # Seed before constructing the model so parameter initialisation is reproducible.
    torch.manual_seed(0)
    model = ReformerModelWithLMHead(ReformerConfig(**config)).to(torch_device)
    model.train()
    model.zero_grad()

    input_ids, _ = self._get_input_ids_and_mask()
    loss = model(input_ids=input_ids, labels=input_ids)[0]
    self.assertTrue(torch.allclose(loss, reference(5.7786), atol=1e-3))

    loss.backward()

    # Inspect tensors at the embedding end of the network, i.e. the last ones
    # reached by the backward pass.
    embeddings = model.reformer.embeddings
    word_grad_slice = embeddings.word_embeddings.weight.grad[0, :5]
    # NOTE(review): the two slices below are taken from the position-embedding
    # weights themselves, not from their `.grad` — confirm this is intended.
    pos_factor_1_slice = embeddings.position_embeddings.weights[0][1, 0, -5:]
    pos_factor_2_slice = embeddings.position_embeddings.weights[1][0, 1, :5]

    checks = (
        (word_grad_slice, reference([-0.0005, 0.0001, 0.0002, 0.0003, 0.0006])),
        (pos_factor_1_slice, reference([0.0037, -1.3793, -1.0231, -1.5230, -2.5306])),
        (pos_factor_2_slice, reference([-1.3165, 0.5168, 0.7785, 1.0811, -0.9830])),
    )
    for actual, expected in checks:
        self.assertTrue(torch.allclose(actual, expected, atol=1e-3))
from collections import OrderedDict
import json

# Training-loss bookkeeping for a short training run.
# Losses are stored under keys of the form "Epoch <x> Step <i>". The `all_*`
# dicts are created once; `training_loss` / `val_loss` are recreated at the
# start of every epoch.
all_training_loss = OrderedDict()
all_val_loss = OrderedDict()

for x in range(1):
    print(f"epoch {x}")
    start = time.time()  # epoch start timestamp; not read within this portion
    training_loss = OrderedDict()
    val_loss = OrderedDict()

    for i in range(NUM_BATCHES):
        print(f"step {i}")
        model.train()

        # `train_loader` is consumed as an iterator yielding dict-like batches.
        tmp = next(train_loader)
        input_ids = tmp['input_ids']
        attention_mask = tmp['attention_mask']
        labels = tmp['labels']

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss, prediction_scores = outputs[:2]
        loss.backward()
        # NOTE(review): no optimizer step or gradient reset is visible in this
        # portion of the script — confirm it happens elsewhere, otherwise
        # gradients accumulate across steps without updating the weights.

        # Convert the loss to a Python number once and reuse it, instead of
        # calling `.item()` for every consumer.
        loss_value = loss.item()
        step_key = f"Epoch {x} Step {i}"
        training_loss[step_key] = loss_value
        all_training_loss[step_key] = loss_value
        print(f'training loss: {loss_value}')