from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score

# Repo-internal names used below (TransformerModelWrapper, InputExample,
# EvalConfig, simple_accuracy, exact_match, distillation_loss) are assumed
# to be importable from the surrounding codebase.


def sequence_classifier_train_step(self, batch: Dict[str, torch.Tensor], use_logits: bool = False,
                                   temperature: float = 1.0, **_) -> torch.Tensor:
    """Perform a sequence classifier training step."""
    inputs = self.generate_default_inputs(batch)
    if not use_logits:
        inputs['labels'] = batch['labels']

    outputs = self.model(**inputs)

    if use_logits:
        # Distillation: compare the model's logits against the soft targets
        # stored in the batch instead of computing a cross-entropy loss.
        logits_predicted, logits_target = outputs[0], batch['logits']
        return distillation_loss(logits_predicted, logits_target, temperature)
    else:
        # When gold labels are passed to the model, outputs[0] is the loss.
        return outputs[0]
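# A minimal usage sketch for the step above. The names `wrapper` (a
# TransformerModelWrapper instance), `batch`, and `teacher_logits` are
# hypothetical and not part of the original code. With use_logits=True the
# batch must carry soft teacher logits under 'logits', and the step returns
# a distillation loss instead of a cross-entropy loss:
#
#     batch['logits'] = teacher_logits  # soft targets from a teacher model
#     loss = wrapper.sequence_classifier_train_step(batch, use_logits=True, temperature=2.0)
#     loss.backward()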
def mlm_train_step(self, labeled_batch: Dict[str, torch.Tensor],
                   unlabeled_batch: Optional[Dict[str, torch.Tensor]] = None, lm_training: bool = False,
                   alpha: float = 0, mlm_logits: bool = False, temperature: float = 1.0,
                   **_) -> torch.Tensor:
    """Perform an MLM training step."""
    inputs = self.generate_default_inputs(labeled_batch)
    mlm_labels, labels = labeled_batch['mlm_labels'], labeled_batch['labels']

    outputs = self.model(**inputs)
    # Map the MLM logits at the masked position to logits over the task's
    # label set via the pattern-verbalizer pair (PVP).
    prediction_scores = self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(mlm_labels, outputs[0])

    if mlm_logits:
        # Distillation: match the soft teacher logits attached to the batch.
        loss = distillation_loss(prediction_scores, labeled_batch['logits'], temperature)
    else:
        loss = nn.CrossEntropyLoss()(prediction_scores.view(-1, len(self.config.label_list)), labels.view(-1))

    if lm_training:
        # Auxiliary language modeling on unlabeled data, interpolated with
        # the classification loss via alpha.
        lm_inputs = self.generate_default_inputs(unlabeled_batch)
        lm_inputs['masked_lm_labels'] = unlabeled_batch['mlm_labels']
        lm_loss = self.model(**lm_inputs)[0]
        loss = alpha * loss + (1 - alpha) * lm_loss

    return loss
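# Both training steps above delegate their soft-target objective to
# distillation_loss, which is not defined in this excerpt. Below is a
# minimal sketch of the standard formulation (temperature-scaled KL
# divergence between student and teacher logits, following Hinton et al.,
# 2015); the reduction and scaling used in the actual codebase may differ.

import torch
import torch.nn.functional as F

def distillation_loss(predictions: torch.Tensor, targets: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """KL divergence between temperature-softened student and teacher distributions."""
    p = F.log_softmax(predictions / temperature, dim=-1)  # student log-probabilities
    q = F.softmax(targets / temperature, dim=-1)          # teacher probabilities
    # The T^2 factor keeps gradient magnitudes comparable across temperatures;
    # dividing by the batch size gives a per-example average.
    return F.kl_div(p, q, reduction='sum') * (temperature ** 2) / predictions.shape[0]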
def evaluate(model: TransformerModelWrapper, eval_data: List[InputExample], config: EvalConfig,
             priming_data: List[InputExample] = None) -> Dict:
    """
    Evaluate a model.

    :param model: the model to evaluate
    :param eval_data: the examples for evaluation
    :param config: the evaluation config
    :param priming_data: an optional list of priming data to use
    :return: a dictionary containing the model's logits, predictions and (if any metrics are given) scores
    """
    if config.priming:
        # Attach the priming examples to each evaluation example's metadata.
        for example in eval_data:
            example.meta['priming_data'] = priming_data

    metrics = config.metrics if config.metrics else ['acc']
    device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")

    model.model.to(device)
    results = model.eval(eval_data, device, per_gpu_eval_batch_size=config.per_gpu_eval_batch_size,
                         n_gpu=config.n_gpu, decoding_strategy=config.decoding_strategy,
                         priming=config.priming)

    predictions = np.argmax(results['logits'], axis=1)
    scores = {}

    for metric in metrics:
        if metric == 'acc':
            scores[metric] = simple_accuracy(predictions, results['labels'])
        elif metric == 'f1':
            scores[metric] = f1_score(results['labels'], predictions)
        elif metric == 'f1-macro':
            scores[metric] = f1_score(results['labels'], predictions, average='macro')
        elif metric == 'em':
            scores[metric] = exact_match(predictions, results['labels'], results['question_ids'])
        elif metric == 'dist-loss':
            if eval_data[0].logits is not None:
                # Distillation loss between the model's logits and the soft
                # targets attached to the evaluation examples.
                scores[metric] = distillation_loss(
                    torch.tensor(results['logits']),
                    torch.stack([torch.tensor(ex.logits, dtype=torch.float32) for ex in eval_data]),
                    config.temperature)
            else:
                scores[metric] = 0.
        else:
            raise ValueError(f"Metric '{metric}' not implemented")

    results['scores'] = scores
    results['predictions'] = predictions
    return results
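# A minimal usage sketch for evaluate. `wrapper` and `examples` are
# hypothetical, and the EvalConfig fields shown are illustrative guesses
# rather than the config's documented signature:
#
#     eval_config = EvalConfig(device='cuda', n_gpu=1, per_gpu_eval_batch_size=8,
#                              metrics=['acc', 'f1-macro'], priming=False,
#                              decoding_strategy='default')
#     results = evaluate(wrapper, examples, eval_config)
#     print(results['scores'])  # dict mapping each metric name to its score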