def check_seq_classifier_loss(self, config, input_values, *args):
    """Check that the classification loss changes when an attention mask is supplied.

    The first three rows of ``input_values`` are zero-padded to 1/4, 1/2 and
    the full last-dimension length. The loss is then computed once with a
    matching attention mask and once without; both must be Python floats and
    they must differ.

    Args:
        config: Configuration used to build ``SEWDForSequenceClassification``.
        input_values: Batch of inputs; only the first three rows are used.
        *args: Ignored.
    """
    model = SEWDForSequenceClassification(config=config)
    model.to(torch_device)

    # make sure that dropout is disabled
    model.eval()

    # NOTE(review): if input_values is a torch tensor this slice is a view, so
    # the padding below also writes into the caller's tensor -- confirm callers
    # do not reuse it afterwards.
    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

    full_length = input_values.shape[-1]
    input_lengths = [full_length // divisor for divisor in (4, 2, 1)]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input: zero everything past each row's length and mask it out
    for row, valid_length in enumerate(input_lengths):
        input_values[row, valid_length:] = 0.0
        attention_mask[row, valid_length:] = 0

    masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
    unmasked_loss = model(input_values, labels=labels).loss.item()

    for loss_value in (masked_loss, unmasked_loss):
        self.parent.assertTrue(isinstance(loss_value, float))
    self.parent.assertTrue(masked_loss != unmasked_loss)
def check_seq_classifier_training(self, config, input_values, *args):
    """Run one training-mode forward/backward pass and check the loss is not infinite.

    The model is built from ``config`` (with ``ctc_zero_infinity`` set to
    ``True``), its base model is frozen, and the first three rows of
    ``input_values`` are zero-padded to 1/4, 1/2 and the full last-dimension
    length before the loss is computed and back-propagated.

    Args:
        config: Configuration used to build ``SEWDForSequenceClassification``;
            it is modified in place.
        input_values: Batch of inputs; only the first three rows are used.
        *args: Ignored.
    """
    config.ctc_zero_infinity = True
    model = SEWDForSequenceClassification(config=config)
    model.to(torch_device)
    model.train()

    # freeze everything but the classification head
    model.freeze_base_model()

    input_values = input_values[:3]
    full_length = input_values.shape[-1]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input: row k keeps only the first full_length // divisor values
    for row, divisor in enumerate((4, 2, 1)):
        input_values[row, full_length // divisor :] = 0.0

    loss = model(input_values, labels=labels).loss
    loss_is_infinite = torch.isinf(loss).item()
    self.parent.assertFalse(loss_is_infinite)

    loss.backward()