Example 1
def test_v1_4_0_deprecated_metrics():
    from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
    with pytest.deprecated_call(match='will be removed in v1.4'):
        stat_scores_multiple_classes(pred=torch.tensor([0, 1]),
                                     target=torch.tensor([0, 1]))

    from pytorch_lightning.metrics.functional.classification import iou
    with pytest.deprecated_call(match='will be removed in v1.4'):
        iou(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3)))

    from pytorch_lightning.metrics.functional.classification import recall
    with pytest.deprecated_call(match='will be removed in v1.4'):
        recall(torch.randint(0, 2, (10, 3, 3)),
               torch.randint(0, 2, (10, 3, 3)))

    from pytorch_lightning.metrics.functional.classification import precision
    with pytest.deprecated_call(match='will be removed in v1.4'):
        precision(torch.randint(0, 2, (10, 3, 3)),
                  torch.randint(0, 2, (10, 3, 3)))

    from pytorch_lightning.metrics.functional.classification import precision_recall
    with pytest.deprecated_call(match='will be removed in v1.4'):
        precision_recall(torch.randint(0, 2, (10, 3, 3)),
                         torch.randint(0, 2, (10, 3, 3)))

    # Testing deprecation of class_reduction arg in the *new* precision
    from pytorch_lightning.metrics.functional import precision
    with pytest.deprecated_call(match='will be removed in v1.4'):
        precision(torch.randint(0, 2, (10, )),
                  torch.randint(0, 2, (10, )),
                  class_reduction='micro')

    # Testing deprecation of class_reduction arg in the *new* recall
    from pytorch_lightning.metrics.functional import recall
    with pytest.deprecated_call(match='will be removed in v1.4'):
        recall(torch.randint(0, 2, (10, )),
               torch.randint(0, 2, (10, )),
               class_reduction='micro')

    from pytorch_lightning.metrics.functional.classification import auc
    with pytest.deprecated_call(match='will be removed in v1.4'):
        auc(torch.rand(10, ).sort().values, torch.rand(10, ))

    from pytorch_lightning.metrics.functional.classification import auroc
    with pytest.deprecated_call(match='will be removed in v1.4'):
        auroc(torch.rand(10, ), torch.randint(0, 2, (10, )))

    from pytorch_lightning.metrics.functional.classification import multiclass_auroc
    with pytest.deprecated_call(match='will be removed in v1.4'):
        multiclass_auroc(torch.rand(20, 5).softmax(dim=-1),
                         torch.randint(0, 5, (20, )),
                         num_classes=5)

    from pytorch_lightning.metrics.functional.classification import auc_decorator
    with pytest.deprecated_call(match='will be removed in v1.4'):
        auc_decorator()

    from pytorch_lightning.metrics.functional.classification import multiclass_auc_decorator
    with pytest.deprecated_call(match='will be removed in v1.4'):
        multiclass_auc_decorator()
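The test above only verifies that the old pytorch_lightning.metrics.functional.classification entry points raise deprecation warnings. For comparison, a minimal sketch of the non-deprecated call path, assuming a pre-1.4 PyTorch Lightning (around 1.2/1.3) where the reworked metrics live under pytorch_lightning.metrics.functional:

import torch
from pytorch_lightning.metrics.functional import auroc

preds = torch.rand(10)                  # predicted probabilities in [0, 1]
target = torch.randint(0, 2, (10,))     # binary ground truth
score = auroc(preds, target, pos_label=1)   # scalar tensor with the ROC AUC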
Example 2
def group_aucs(predictions: torch.Tensor, targets: torch.Tensor, memberships: torch.Tensor) -> Dict[int, float]:
    """Computes the AUC for each protected group.

    Args:
        predictions: Tensor of shape (n_samples, ) with prediction logits.
        targets: Tensor of shape (n_samples, ) with ground truth (0 or 1).
        memberships: Tensor of shape (n_samples, ) with group membership indices.

    Returns:
        A dict mapping each group index to its AUC value.
    """
    
    groups = memberships.unique().to(predictions.device)
    targets = targets.to(predictions.device)
    memberships = memberships.to(predictions.device)
    aucs: Dict[int, float] = {}
    
    for group in groups:
        indices = (memberships == group)
        if torch.sum(targets[indices]) == 0 or torch.sum(1-targets[indices]) == 0:
            aucs[int(group)] = 0 
        else:
            aucs[int(group)] = auroc(predictions[indices], targets[indices]).item()
    return aucs
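A short usage sketch for group_aucs with synthetic tensors; the values are made up and only the shapes follow the docstring above:

import torch

preds = torch.rand(8)                                  # prediction scores, shape (n_samples,)
targets = torch.randint(0, 2, (8,))                    # binary ground truth
memberships = torch.tensor([0, 0, 0, 1, 1, 1, 1, 0])   # protected-group index per sample

per_group_auc = group_aucs(preds, targets, memberships)
# e.g. {0: 0.75, 1: 0.5}; a group whose targets contain only one class is reported as 0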
Example 3
def get_metric(network, loader, weights, device, metric):
    if weights is not None:
        raise ValueError("Non-uniform weights not supported")

    network.eval()
    ys = []
    ps = []
    sigmoid = nn.Sigmoid()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            ys.extend(y)
            ps.extend(sigmoid(network.predict(x)))
        ps = torch.stack(ps).to(device)
        ys = torch.stack(ys).to(device)
        if metric == 'micro_f1':
            result = f1_score(ps,
                              ys,
                              num_classes=None,
                              class_reduction='micro')
        elif metric == 'macro_f1':
            result = f1_score(ps,
                              ys,
                              num_classes=None,
                              class_reduction='macro')
        elif metric == 'auroc':
            result = []
            for d in range(ps.size(1)):
                result.append(auroc(ps[:, d], ys[:, d]))
    network.train()
    return result
Example 4
def run_folds(config: Dict[str, Any],
              args: argparse.Namespace,
              dataset: FairnessDataset,
              fold_indices: List[Tuple[np.ndarray, np.ndarray]],
              version: Optional[str] = None) -> float:
    """Runs kfold cross validation on the given dataset.
    
    Executes single training runs on the training set of each fold and evaluates
    the trained model on the validation set of the same fold.
    
    Args:
        config: Dict with hyperparameters (learning rate, batch size, eta).
        args: Object from the argument parser that defines various settings of
            the model, dataset and training.
        dataset: Dataset instance that will be used for cross validation.
        fold_indices: Indices to select training and validation subsets for each
            fold.
        version: Optional; version used for the logging directory.
    
    Returns:
        Mean micro-average AUC of the trained models on their corresponding
        validation folds.
    """

    print(
        f'Starting run with seed {args.seed} - lr {config["lr"]} - sec_lr {config["sec_lr"]} - bs {config["batch_size"]}'
    )

    fold_nbr = 0
    aucs: List[float] = []
    for train_idcs, val_idcs in fold_indices:
        fold_nbr += 1

        # create datasets for fold
        train_dataset = CustomSubset(dataset, train_idcs)
        val_dataset = CustomSubset(dataset, val_idcs)

        # train model
        t: Tuple[pl.LightningModule,
                 pl.Trainer] = train(config,
                                     args,
                                     train_dataset=train_dataset,
                                     val_dataset=val_dataset,
                                     version=args.version,
                                     fold_nbr=fold_nbr)
        model, _ = t

        # Evaluate on val set to get an estimate of performance
        scores: torch.Tensor = torch.sigmoid(model(val_dataset.features))
        aucs.append(auroc(scores, val_dataset.labels).item())

    mean_auc: float = np.mean(aucs)
    print(
        f'Finished run with seed {args.seed} - lr {config["lr"]} - sec_lr {config["sec_lr"]} - bs {config["batch_size"]} - mean val auc:'
        f' {mean_auc}')

    tune.report(auc=mean_auc)

    return mean_auc
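fold_indices is a list of (train, val) index pairs, one per fold. A minimal sketch of how such a list could be built, assuming scikit-learn's KFold; config, args and dataset come from the surrounding project and are not defined here:

import numpy as np
from sklearn.model_selection import KFold

n_samples = 100  # placeholder; in practice len(dataset)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_indices = [(train_idcs, val_idcs) for train_idcs, val_idcs in kf.split(np.arange(n_samples))]
# mean_val_auc = run_folds(config, args, dataset, fold_indices)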
Example 5
 def val_test_epoch_end(self, outputs, log_as="val"):
     y_pred = torch.cat([out[0] for out in outputs], dim=0)
     y = torch.cat([out[1] for out in outputs], dim=0)
     auc = auroc(y_pred, y)
     if log_as == "val":
         self.log(f"avg_{log_as}_auc", auc, prog_bar=True)
     else:
         self.log(f"avg_{log_as}_auc", auc)
Example 6
    def _step(self, batch, hiddens=None, calculate_roc=False):
        actual = batch["y"].clone()  # (batch, seq)

        ans_prev_correctly = batch["y"].clone()
        ans_prev_correctly = torch.roll(ans_prev_correctly, 1, dims=1)
        ans_prev_correctly[:, 0] = -1
        ans_prev_correctly = torch.unsqueeze(ans_prev_correctly, dim=2)

        seq_len_mask = batch["seq_len_mask"]  # (batch, seq)
        question_mask = batch["question_mask"]  # (batch, seq)

        batch_size, seq_len = actual.shape
        actual = actual.view(batch_size * seq_len).float()
        actual[actual < 0] = 0

        content_id = batch["content_id"]
        bundle_id = batch["bundle_id"]
        feature = batch["feature"]
        user_id = batch["user_id"]

        # assert torch.isnan(content_id).sum() == 0
        # assert torch.isnan(bundle_id).sum() == 0
        # assert torch.isnan(feature).sum() == 0
        # assert torch.isnan(seq_len_mask).sum() == 0

        content_id[torch.isnan(content_id)] = 0
        bundle_id[torch.isnan(bundle_id)] = 0
        feature[torch.isnan(feature)] = 0
        seq_len_mask[torch.isnan(seq_len_mask)] = 0

        pred, hiddens = self.forward(content_id=content_id,
                                     bundle_id=bundle_id,
                                     feature=feature,
                                     user_id=user_id,
                                     mask=seq_len_mask,
                                     initial_state=hiddens,
                                     ans_prev_correctly=ans_prev_correctly)
        pred = pred.view(batch_size * seq_len)

        flatten_mask = (seq_len_mask & question_mask).view(batch_size * seq_len)
        loss = self.criterion(input=pred,
                              target=actual,
                              weight=flatten_mask,
                              reduction="sum")

        if self.hparams.get("b_flooding") is not None:
            b = torch.tensor(self.hparams["b_flooding"], dtype=torch.float)
            loss = torch.abs(loss - b) + b

        if calculate_roc:
            auc_score = auroc(pred=pred,
                              target=actual,
                              sample_weight=flatten_mask)
            return loss, hiddens, auc_score
        else:
            return loss, hiddens
Example 7
 def compute(self):
     preds = torch.cat(self.all_preds).view(-1)
     target = torch.cat(self.all_target).view(-1)
     try:
         return auroc(preds, target)
     except ValueError:
         logging.warning(
             'AUROC requires both negative and positive samples. Returning None'
         )
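The guard exists because AUROC is undefined when the targets contain only one class; the functional auroc used here raises ValueError in that case. A minimal reproduction sketch:

import torch

preds = torch.rand(4)
all_negative = torch.zeros(4, dtype=torch.long)   # no positive samples
try:
    auroc(preds, all_negative)
except ValueError:
    print("undefined AUROC: only one class present in target")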
Example 8
    def _step(self, batch, hiddens=None, calculate_roc=False):
        actual: torch.Tensor = batch["y"].clone()  # (batch, seq)

        seen_content_feedback = self.__class__.to_seen_content_feedback(actual)

        seq_len_mask = batch["seq_len_mask"]  # (batch, seq)
        question_mask = batch["question_mask"]  # (batch, seq)

        batch_size, seq_len = actual.shape
        actual = actual.view(batch_size * seq_len).float()
        actual[actual < 0] = 0

        query_content_id = batch["content_id"]
        query_content_feature = batch["feature"]

        query_content_id[torch.isnan(query_content_id)] = 0
        query_content_feature[torch.isnan(query_content_feature)] = 0
        seq_len_mask[torch.isnan(seq_len_mask)] = 0

        seen_content_id = self.__class__._shift_tensor(query_content_id)
        seen_content_feature = self.__class__._shift_tensor(query_content_feature)

        pred, hiddens = self.forward(query_content_id=query_content_id,
                                     query_content_feature=query_content_feature,
                                     mask=seq_len_mask,
                                     seen_content_id=seen_content_id,
                                     seen_content_feature=seen_content_feature,
                                     seen_content_feedback=seen_content_feedback,
                                     initial_state=hiddens)
        pred = pred.view(batch_size * seq_len)

        flatten_mask = (seq_len_mask & question_mask).view(batch_size * seq_len)
        loss = self.criterion(input=pred,
                              target=actual,
                              weight=flatten_mask,
                              reduction="sum")

        if self.hparams.get("b_flooding") is not None:
            b = torch.tensor(self.hparams["b_flooding"], dtype=torch.float)
            loss = torch.abs(loss - b) + b

        if calculate_roc:
            auc_score = auroc(pred=pred,
                              target=actual,
                              sample_weight=flatten_mask)
            return loss, hiddens, auc_score
        else:
            return loss, hiddens
Example 9
    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                sample_weight: Optional[Sequence] = None) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: classification score
        """
        return auroc(pred=pred,
                     target=target,
                     sample_weight=sample_weight,
                     pos_label=self.pos_label)
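Given the keyword arguments this forward() passes through, the call is equivalent to invoking the old functional metric directly; a sketch assuming the pre-torchmetrics pytorch_lightning API and pos_label=1:

import torch
from pytorch_lightning.metrics.functional.classification import auroc

score = auroc(pred=torch.rand(16),
              target=torch.randint(0, 2, (16,)),
              sample_weight=None,
              pos_label=1)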
Example 10
    def training_epoch_end(self, outputs):
        labels = []
        predictions = []

        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)

        for output in outputs:
            for out_preds in output["predictions"].detach().cpu():
                predictions.append(out_preds)

        labels = torch.stack(labels)
        predictions = torch.stack(predictions)

        for i, name in enumerate(LABEL_COLUMNS):
            roc_score = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train",
                                              roc_score, self.current_epoch)
Example 11
def get_all_auc_scores(pl_module: LightningModule, dataloader: DataLoader, minority: int) -> Dict[str, float]:
    """Computes different AUC scores and the accuracy of the given module on
    the given dataset.

    Args:
        pl_module: Model to evaluate the metrics on. 
        dataloader: Dataloader instance used to pass the dataset through the 
            model.
        minority: Index of the protected group with the fewest members.

    Returns:
        A dict mapping keys to the corresponding metric values.
    """
    
    # iterate through dataloader to generate predictions
    predictions: List[torch.Tensor] = []
    memberships: List[torch.Tensor] = []
    targets: List[torch.Tensor] = []
    for x, y, s in iter(dataloader):
        x = x.to(pl_module.device)
        # y and s are simple scalars, no need to move to GPU
        batch_predictions = torch.sigmoid(pl_module(x))
        predictions.append(batch_predictions)
        memberships.append(s)
        targets.append(y)

    prediction_tensor = torch.cat(predictions, dim=0)
    target_tensor = torch.cat(targets, dim=0).to(prediction_tensor.device)
    membership_tensor = torch.cat(memberships, dim=0).to(prediction_tensor.device)

    aucs = group_aucs(prediction_tensor, target_tensor, membership_tensor)
    acc = torch.mean(((prediction_tensor > 0.5).int() == target_tensor).float()).item()
    
    results = {
        'min_auc': min(aucs.values()),
        'macro_avg_auc': mean(aucs.values()),
        'micro_avg_auc': auroc(prediction_tensor, target_tensor).item(),
        'minority_auc': aucs[minority],
        'accuracy': acc
    }
    
    return results
Example 12
def test_auroc(pred, target, expected):
    score = auroc(torch.tensor(pred), torch.tensor(target)).item()
    assert score == expected
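A test like this is normally driven by pytest.mark.parametrize; an illustrative sketch with made-up cases (these concrete values are not the original fixture):

import pytest
import torch

@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    ([0.1, 0.9], [0, 1], 1.0),   # positive ranked above negative -> AUROC 1.0
    ([0.9, 0.1], [0, 1], 0.0),   # positive ranked below negative -> AUROC 0.0
])
def test_auroc(pred, target, expected):
    score = auroc(torch.tensor(pred), torch.tensor(target)).item()
    assert score == expected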
Example 13
 def compute(self) -> torch.Tensor:
     preds, targets = self._get_preds_and_targets()
     if torch.unique(targets).numel() == 1:
         return torch.tensor(np.nan)
     return auroc(preds, targets)
Example 14
def get_metrics(network, loader, weights, device, name, mode='full'):
    print('Start Evaluation')
    correct = 0
    strict_correct = 0
    total = 0
    strict_total = 0
    weights_offset = 0
    ys = []
    ps = []
    sigmoid = nn.Sigmoid()

    network.eval()
    with torch.no_grad():
        t = tqdm(iter(loader), leave=False, total=len(loader))
        for i, data in enumerate(t):
            x, y = data
            if mode == 'skip' and i >= 100:
                break
            x = x.to(device)
            y = y.to(device)
            p = sigmoid(network.predict(x))
            ys.append(y)
            ps.append(p)
            if weights is None:
                batch_weights = torch.ones(len(x))
            else:
                batch_weights = weights[weights_offset:weights_offset + len(x)]
                weights_offset += len(x)
            batch_weights = batch_weights.to(device)
            strict_correct += ((p.gt(.5) == y).all().float() *
                               batch_weights.reshape((-1, 1))).sum().item()
            correct += ((p.gt(.5) == y).float() * batch_weights.reshape(
                (-1, 1))).sum().item()
            total += p.size(0) * p.size(1)
            strict_total += batch_weights.sum().item()
        ps = torch.cat(ps).to(device)
        ys = torch.cat(ys).to(device)
        eces = get_ece(ps, ys).item()
        # micro_f1 = f1_score(ps.gt(.5).float(), ys, num_classes=None, class_reduction='micro').item()
        # macro_f1 = f1_score(ps.gt(.5).float(), ys, num_classes=None, class_reduction='macro').item()
        aucs = []
        micro_f1 = []
        macro_f1 = []
        for d in range(ps.size(1)):
            micro = F1(num_classes=2, average='micro')
            macro = F1(num_classes=2, average='macro')
            micro_f1.append(
                micro(ps[:, d].gt(.5).cpu().long(), ys[:, d].cpu().long()).item())
            macro_f1.append(
                macro(ps[:, d].gt(.5).cpu().long(), ys[:, d].cpu().long()).item())
            aucs.append(auroc(ps[:, d], ys[:, d]).item())
    network.train()
    results = {
        f'{name}_acc': correct / total,
        f'{name}_strict_acc': strict_correct / strict_total,
        f'{name}_auc': aucs,
        f'{name}_micro_f1': micro_f1,
        f'{name}_macro_f1': macro_f1,
        f'{name}_eces': eces
    }
    return results
Example 15
def compute_evaluation_metrics(outputs: List[List[torch.Tensor]],
                               plot: bool = False,
                               prefix: Optional[str] = None) -> Dict[str, torch.Tensor]:
    scores = torch.cat([s for step in outputs for s in step[0]])
    # NOTE: Need sigmoid here because we skip the sigmoid in forward() due to using BCE with logits for loss.
    #scores = torch.sigmoid(scores)
    print('Score range: [{}, {}]'
          .format(torch.min(scores).item(),
                  torch.max(scores).item()))
    labels = torch.cat([lab for step in outputs for lab in step[1]])

    auc = auroc(scores, labels, pos_label=1)
    fpr, tpr, thresholds = roc(scores, labels, pos_label=1)
    prec, recall = precision_recall(scores, labels)

    # mypy massaging, single tensors when num_classes is not specified (= binary case).
    fpr = cast(torch.Tensor, fpr)
    tpr = cast(torch.Tensor, tpr)
    thresholds = cast(torch.Tensor, thresholds)

    fnr = 1 - tpr
    eer, eer_threshold, idx = equal_error_rate(fpr, fnr, thresholds)
    min_dcf, min_dcf_threshold = minDCF(fpr, fnr, thresholds)

    # Accuracy based on EER and minDCF thresholds.
    eer_preds = (scores >= eer_threshold).long()
    min_dcf_preds = (scores >= min_dcf_threshold).long()
    eer_acc = torch.sum(eer_preds == labels).float() / labels.numel()
    min_dcf_acc = torch.sum(min_dcf_preds == labels).float() / labels.numel()

    if plot:
        assert idx.dim() == 0 or (idx.dim() == 1 and idx.size(0) == 1)
        i = int(idx.item())
        fpr = fpr.cpu().numpy()
        tpr = tpr.cpu().numpy()
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.plot([0, 1], [0, 1], 'r--', label='Reference', alpha=0.6)
        plt.plot([1, 0], [0, 1], 'k--', label='EER line', alpha=0.6)
        plt.plot(fpr, tpr, label='ROC curve')
        plt.fill_between(fpr, tpr, 0, label='AUC', color='0.8')
        plt.plot(fpr[i], tpr[i], 'ko', label='EER = {:.2f}%'.format(eer * 100))  # EER point
        plt.legend()
        plt.show()

    if prefix:
        prefix = '{}_'.format(prefix)
    else:
        prefix = ''

    return {
        '{}eer'.format(prefix): eer,
        '{}eer_acc'.format(prefix): eer_acc,
        '{}eer_threshold'.format(prefix): eer_threshold,
        '{}auc'.format(prefix): auc,
        '{}min_dcf'.format(prefix): min_dcf,
        '{}min_dcf_acc'.format(prefix): min_dcf_acc,
        '{}min_dcf_threshold'.format(prefix): min_dcf_threshold,
        '{}prec'.format(prefix): prec,
        '{}recall'.format(prefix): recall
    }
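equal_error_rate and minDCF are project helpers that are not shown in this excerpt. A minimal sketch of how the EER operating point could be located from fpr, fnr and thresholds; this is an illustration, not the project's actual implementation:

import torch

def equal_error_rate_sketch(fpr: torch.Tensor, fnr: torch.Tensor, thresholds: torch.Tensor):
    # The EER is the operating point where the false-positive and false-negative
    # rates cross; take the threshold index minimising |FPR - FNR| and average the two.
    idx = torch.argmin(torch.abs(fpr - fnr))
    eer = (fpr[idx] + fnr[idx]) / 2
    return eer, thresholds[idx], idx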