Code example #1
 def __init__(self):
     super(ClassifierBackBone, self).__init__()
     self.back_bone = nn.Sequential(nn.Conv2d(3, 32, (7, 7), stride=(2, 2)),
                                    ResidualBlock(32, 32),
                                    ResidualBlock(32, 64),
                                    ResidualBottleneck(64, 2),
                                    ResidualBlock(64, 64),
                                    ResidualBlock(64, 128),
                                    ResidualBottleneck(128, 2),
                                    ResidualBlock(128, 128),
                                    ResidualBlock(128, 256),
                                    ResidualBottleneck(256, 2),
                                    ResidualBlock(256, 256), nn.Flatten(),
                                    nn.Linear(256 * 25 * 25, 1),
                                    nn.Sigmoid())
     self.criterion = torch.nn.BCELoss()
     self.train_metrics = MetricCollection({
         'train_accuracy':
         Accuracy(compute_on_step=False),
         'train_precision':
         Precision(compute_on_step=False),
         'train_recall':
         Recall(compute_on_step=False),
     })
     self.val_metrics = MetricCollection({
         'val_accuracy':
         Accuracy(compute_on_step=False),
         'val_precision':
         Precision(compute_on_step=False),
         'val_recall':
         Recall(compute_on_step=False)
     })
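A minimal sketch of how a collection like self.train_metrics is typically fed from a LightningModule training step; the method below is not part of the original snippet, and the batch layout (image tensor x, binary target y) is an assumption:

 def training_step(self, batch, batch_idx):
     x, y = batch
     probs = self.back_bone(x).squeeze(1)        # sigmoid output in [0, 1]
     loss = self.criterion(probs, y.float())
     self.train_metrics.update(probs, y.int())   # accumulated here, computed at epoch end
     self.log("train_loss", loss)
     return loss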
Code example #2
    def __init__(
            self,
            arch: str,
            optcfg: DictConfig,
            arch_ckpt: Optional[str] = None,
            schcfg: Optional[DictConfig] = None,
            **kwargs,
    ):
        super().__init__()

        self.schcfg = schcfg
        self.optcfg = optcfg
        self.save_hyperparameters()

        if arch_ckpt: 
            arch = arch_ckpt
        self.transformer = AutoModelForSequenceClassification.from_pretrained(arch, num_labels=7)

        # loss function
        self.criterion = nn.CrossEntropyLoss()

        # metrics
        mc = MetricCollection({
            "accuracy": Accuracy(threshold=0.0),
            "recall": Recall(threshold=0.0, num_classes=7, average='macro'),
            "precision": Precision(threshold=0.0, num_classes=7, average='macro'),
            "f1": F1(threshold=0.0, num_classes=7, average='macro'),
            "macro_auc": AUROC(num_classes=7, average='macro'),
            # "weighted_auc": AUROC(num_classes=7, average='weighted')
        })
        self.metrics: ModuleDict[str, MetricCollection] = ModuleDict({
            f"{phase}_metric": mc.clone()
            for phase in ["train", "valid", "test"]
        })
Code example #3
    def __init__(self,
                 num_classes: int = 2,
                 ignore_index: Optional[int] = None,
                 lr: float = 0.35,
                 weight_decay: float = 0):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.lr = lr
        self.weight_decay = weight_decay

        # Create LR-ASPP MobileNetV3-Large segmentation model
        self.model = lraspp_mobilenet_v3_large(progress=True,
                                               num_classes=self.num_classes)
        self.model.requires_grad_(True)

        # Loss function
        self.focal_tversky_loss = FocalTverskyMetric(
            self.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=self.ignore_index)
        self.accuracy_metric = Accuracy(ignore_index=self.ignore_index)
        self.iou_metric = JaccardIndex(num_classes=self.num_classes,
                                       reduction="none",
                                       ignore_index=self.ignore_index)
        self.precision_metric = Precision(num_classes=self.num_classes,
                                          ignore_index=self.ignore_index,
                                          average='weighted',
                                          mdmc_average='global')
        self.recall_metric = Recall(num_classes=self.num_classes,
                                    ignore_index=self.ignore_index,
                                    average='weighted',
                                    mdmc_average='global')
Code example #4
 def configure_metrics(self, _) -> None:
     self.prec = Precision(num_classes=self.num_classes)
     self.recall = Recall(num_classes=self.num_classes)
     self.acc = Accuracy()
     self.metrics = {
         "precision": self.prec,
         "recall": self.recall,
         "accuracy": self.acc
     }
Code example #5
def get_metrics_collections_base(NUM_CLASS, prefix):

    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS,
                                         average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS,
                                         average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="micro"),
        },
        prefix=prefix)

    return metrics
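A hedged usage sketch for the collection above; the loader, logits, and targets names are illustrative and not part of the original:

train_metrics = get_metrics_collections_base(NUM_CLASS=10, prefix="train_")
for logits, targets in loader:
    train_metrics.update(logits.softmax(dim=-1), targets)
print(train_metrics.compute())  # e.g. {"train_Accuracy": ..., "train_Top_3": ...}
train_metrics.reset()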
Code example #6
 def __init__(self, num_classes):
     self.metrics = [
         ("acc", Accuracy(num_classes=num_classes, average="micro")),
         ("f1", F1Score(num_classes=num_classes, average="micro")),
         ("precision", Precision(num_classes=num_classes, average="micro")),
         ("recall", Recall(num_classes=num_classes, average="micro")),
     ]
     if num_classes > 2:
         self.metrics += [
             ("macro_acc", Accuracy(num_classes=num_classes,
                                    average="macro")),
             ("macro_f1", F1Score(num_classes=num_classes,
                                  average="macro")),
             (
                 "macro_precision",
                 Precision(num_classes=num_classes, average="macro"),
             ),
             ("macro_recall",
              Recall(num_classes=num_classes, average="macro")),
         ]
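Because the metrics here are stored as (name, metric) tuples rather than a MetricCollection, a plausible evaluation helper (illustrative only, assuming preds and targets are class-index tensors) would iterate them explicitly:

 def evaluate_batch(self, preds, targets):
     results = {}
     for name, metric in self.metrics:
         results[name] = metric(preds, targets)  # forward() updates state and returns the batch value
     return results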
Code example #7
def get_metrics_collections_base(NUM_CLASS,
                                 # device="cuda" if torch.cuda.is_available() else "cpu",
                                 ):
    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS, average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS, average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="macro"),
        }
    )
    return metrics
Code example #8
def get_metrics(metric_threshold, monitor_metrics, num_classes):
    macro_prec = Precision(num_classes, metric_threshold, average='macro')
    macro_recall = Recall(num_classes, metric_threshold, average='macro')
    another_macro_f1 = 2 * (macro_prec * macro_recall) / (macro_prec +
                                                          macro_recall + 1e-10)
    metrics = {
        'Micro-Precision':
        Precision(num_classes, metric_threshold, average='micro'),
        'Micro-Recall':
        Recall(num_classes, metric_threshold, average='micro'),
        'Micro-F1':
        F1(num_classes, metric_threshold, average='micro'),
        'Macro-F1':
        F1(num_classes, metric_threshold, average='macro'),
        # The f1 value of macro_precision and macro_recall. This variant of
        # macro_f1 is less preferred but is used in some works. Please
        # refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf]
        'Another-Macro-F1':
        another_macro_f1,
    }
    for metric in monitor_metrics:
        if isinstance(metric, Metric):  # customized metric
            metrics[type(metric).__name__] = metric
        elif re.match(r'P@\d+', metric):
            metrics[metric] = Precision(num_classes,
                                        average='samples',
                                        top_k=int(metric[2:]))
        elif re.match(r'R@\d+', metric):
            metrics[metric] = Recall(num_classes,
                                     average='samples',
                                     top_k=int(metric[2:]))
        elif metric not in [
                'Micro-Precision', 'Micro-Recall', 'Micro-F1', 'Macro-F1',
                'Another-Macro-F1'
        ]:
            raise ValueError(f'Invalid metric: {metric}')

    return MetricCollection(metrics)
Code example #9
File: train_model.py Project: Tobias-Fischer/rt_gene
    def __init__(self,
                 hparams,
                 train_subjects,
                 validate_subjects,
                 class_weights=None):
        super(TrainRTBENE, self).__init__()
        assert class_weights is not None, "Class Weights can't be None"

        self.model = MODELS[hparams.model_base]()
        self._criterion = torch.nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([class_weights[1]]))
        self._train_subjects = train_subjects
        self._validate_subjects = validate_subjects
        self._metrics = MetricCollection(
            [Accuracy(),
             F1(), Precision(),
             Recall(), Specificity()])
        self.save_hyperparameters(
            hparams,
            ignore=["train_subjects", "validate_subjects", "class_weights"])
Code example #10
File: lit_unet.py Project: tayden/uav-classif
    def __init__(self, hparams):
        """hparams must be a dict of {weight_decay, lr, num_classes}"""
        super().__init__()
        self.save_hyperparameters(hparams)

        # Create U-Net model with an ImageNet-pretrained EfficientNet-B4 encoder
        self.model = Unet(
            encoder_name="efficientnet-b4",
            encoder_weights="imagenet",
            in_channels=3,
            classes=self.hparams.num_classes,
        )
        self.model.requires_grad_(True)
        self.model.encoder.requires_grad_(False)

        # Loss function and metrics
        self.focal_tversky_loss = FocalTverskyMetric(
            self.hparams.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=self.hparams.get("ignore_index"),
        )
        self.accuracy_metric = Accuracy(
            ignore_index=self.hparams.get("ignore_index"))
        self.iou_metric = JaccardIndex(
            num_classes=self.hparams.num_classes,
            reduction="none",
            ignore_index=self.hparams.get("ignore_index"),
        )
        self.precision_metric = Precision(
            num_classes=self.hparams.num_classes,
            ignore_index=self.hparams.get("ignore_index"),
            average='weighted',
            mdmc_average='global',
        )
        self.recall_metric = Recall(
            num_classes=self.hparams.num_classes,
            ignore_index=self.hparams.get("ignore_index"),
            average='weighted',
            mdmc_average='global',
        )
Code example #11
    def __init__(self, num_classes: int = 2, ignore_index: Optional[int] = None, lr: float = 0.001,
                 weight_decay: float = 0.001, aux_loss_factor: float = 0.3):

        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.lr = lr
        self.weight_decay = weight_decay
        self.aux_loss_factor = aux_loss_factor

        # Create model from pre-trained DeepLabv3
        self.model = deeplabv3_resnet101(pretrained=True, progress=True)
        self.model.aux_classifier = FCNHead(1024, self.num_classes)
        self.model.classifier = DeepLabHead(2048, self.num_classes)

        # Setup trainable layers
        self.model.requires_grad_(True)
        self.model.backbone.requires_grad_(False)

        # Loss function and metrics
        self.focal_tversky_loss = FocalTverskyMetric(
            self.num_classes,
            alpha=0.7,
            beta=0.3,
            gamma=4.0 / 3.0,
            ignore_index=self.ignore_index,
        )
        self.accuracy_metric = Accuracy(ignore_index=self.ignore_index)
        self.iou_metric = JaccardIndex(
            num_classes=self.num_classes,
            reduction="none",
            ignore_index=self.ignore_index,
        )
        self.precision_metric = Precision(num_classes=self.num_classes, ignore_index=self.ignore_index,
                                          average='weighted', mdmc_average='global')
        self.recall_metric = Recall(num_classes=self.num_classes, ignore_index=self.ignore_index,
                                    average='weighted', mdmc_average='global')
Code example #12
File: test_bootstrapping.py Project: Borda/metrics
        assert ns in old_samples

    found_one = _sample_checker(old_samples, new_samples, operator.eq, 2)
    assert found_one, "resampling did not work because no samples were sampled twice"

    found_zero = _sample_checker(old_samples, new_samples, operator.ne, 0)
    assert found_zero, "resampling did not work because all samples were atleast sampled once"


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
@pytest.mark.parametrize(
    "metric, sk_metric",
    [
        [
            Precision(average="micro"),
            partial(precision_score, average="micro")
        ],
        [Recall(average="micro"),
         partial(recall_score, average="micro")],
        [MeanSquaredError(), mean_squared_error],
    ],
)
def test_bootstrap(device, sampling_strategy, metric, sk_metric):
    """Test that the different bootstraps gets updated as we expected and that the compute method works."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Test with device='cuda' requires gpu")

    _kwargs = {
        "base_metric": metric,
        "mean": True,
Code example #13
File: metrics.py Project: ASUS-AICS/LibMultiLabel
def get_metrics(metric_threshold, monitor_metrics, num_classes):
    """Map monitor metrics to the corresponding classes defined in `torchmetrics.Metric`
    (https://torchmetrics.readthedocs.io/en/latest/references/modules.html).

    Args:
        metric_threshold (float): Threshold to monitor for metrics.
        monitor_metrics (list): Metrics to monitor while validating.
        num_classes (int): Total number of classes.

    Raises:
        ValueError: The metric is invalid if:
            (1) It is not one of 'P@k', 'R@k', 'RP@k', 'nDCG@k', 'Micro-Precision',
                'Micro-Recall', 'Micro-F1', 'Macro-F1', 'Another-Macro-F1', or a
                `torchmetrics.Metric`.
            (2) Metric@k: k is greater than `num_classes`.

    Returns:
        torchmetrics.MetricCollection: A collections of `torchmetrics.Metric` for evaluation.
    """
    if monitor_metrics is None:
        monitor_metrics = []

    metrics = dict()
    for metric in monitor_metrics:
        if isinstance(metric, Metric):  # customized metric
            metrics[type(metric).__name__] = metric
            continue

        match_top_k = re.match(r'\b(P|R|RP|nDCG)\b@(\d+)', metric)
        match_metric = re.match(r'\b(Micro|Macro)\b-\b(Precision|Recall|F1)\b',
                                metric)

        if match_top_k:
            metric_abbr = match_top_k.group(1)  # P, R, RP, or nDCG
            top_k = int(match_top_k.group(2))
            if top_k >= num_classes:
                raise ValueError(
                    f'Invalid metric: {metric}. {top_k} is greater than {num_classes}.'
                )
            if metric_abbr == 'P':
                metrics[metric] = Precision(num_classes,
                                            average='samples',
                                            top_k=top_k)
            elif metric_abbr == 'R':
                metrics[metric] = Recall(num_classes,
                                         average='samples',
                                         top_k=top_k)
            elif metric_abbr == 'RP':
                metrics[metric] = RPrecision(top_k=top_k)
            elif metric_abbr == 'nDCG':
                metrics[metric] = NDCG(top_k=top_k)
                # The implementation in torchmetrics stores the prediction/target of all batches,
                # which can lead to CUDA out of memory.
                # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
        elif metric == 'Another-Macro-F1':
            metrics[metric] = MacroF1(num_classes,
                                      metric_threshold,
                                      another_macro_f1=True)
        elif metric == 'Macro-F1':
            metrics[metric] = MacroF1(num_classes, metric_threshold)
        elif match_metric:
            average_type = match_metric.group(1).lower()  # micro or macro
            metric_type = match_metric.group(2)  # Precision, Recall, or F1
            metrics[metric] = getattr(torchmetrics.classification,
                                      metric_type)(num_classes,
                                                   metric_threshold,
                                                   average=average_type)
        else:
            raise ValueError(
                f'Invalid metric: {metric}. Make sure the metric is in the right format: Macro/Micro-Precision/Recall/F1 (ex. Micro-F1)'
            )

    return MetricCollection(metrics)
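An illustrative call for the LibMultiLabel helper above; num_labels, logits, and targets are placeholders, not names from the project:

metrics = get_metrics(metric_threshold=0.5,
                      monitor_metrics=['Micro-F1', 'Macro-F1', 'P@5', 'nDCG@10'],
                      num_classes=num_labels)
metrics.update(torch.sigmoid(logits), targets)  # multi-label probabilities vs. 0/1 targets
print(metrics.compute())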
Code example #14
def run_epoch(model, dataloader, criterion, optimizer=None, epoch=0, scheduler=None, device='cpu'):
    import torchmetrics.functional as clmetrics
    from torchmetrics import Precision, Accuracy, Recall
    #import pytorch_lightning.metrics.functional.classification as clmetrics
    #from pytorch_lightning.metrics import Precision, Accuracy, Recall
    from sklearn.metrics import roc_auc_score, average_precision_score

    metrics = Accumulator()
    cnt = 0
    total_steps = len(dataloader)
    steps = 0
    running_corrects = 0
    

    accuracy = Accuracy()
    precision = Precision(num_classes=2)
    recall = Recall(num_classes=2)

    preds_epoch = []
    labels_epoch = []
    for inputs, labels in dataloader:
        steps += 1
        inputs = inputs.to(device) # torch.Size([2, 1, 224, 224])
        labels = labels.to(device).unsqueeze(1).float() ## torch.Size([2, 1])

        outputs = model(inputs) # [batch_size, nb_classes]

        loss = criterion(outputs, labels)

        if optimizer:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        preds_epoch.extend(torch.sigmoid(outputs).tolist())
        labels_epoch.extend(labels.tolist())
        threshold = 0.5
        prob = (torch.sigmoid(outputs)>threshold).long()
        
        conf = torch.flatten(clmetrics.confusion_matrix(prob, labels.to(prob.device, dtype=torch.int), num_classes=2))
        tn, fp, fn, tp = conf

        metrics.add_dict({
            'data_count': len(inputs),
            'loss': loss.item() * len(inputs),
            'tp': tp.item(),
            'tn': tn.item(),
            'fp': fp.item(),
            'fn': fn.item(),
        })
        cnt += len(inputs)

        if scheduler:
            scheduler.step()
        del outputs, loss, inputs, labels, prob
    logger.info(f'cnt = {cnt}')

    metrics['loss'] /= cnt

    def safe_div(x,y):
        if y == 0:
            return 0
        return x / y
    _TP,_TN, _FP, _FN = metrics['tp'], metrics['tn'], metrics['fp'], metrics['fn']
    acc = (_TP+_TN)/cnt
    sen = safe_div(_TP , (_TP + _FN))
    spe = safe_div(_TN , (_FP + _TN))
    prec = safe_div(_TP , (_TP + _FP))
    metrics.add('accuracy', acc)
    metrics.add('sensitivity', sen)
    metrics.add('specificity', spe)
    metrics.add('precision', prec)
    
    try:
        auc = roc_auc_score(labels_epoch, preds_epoch)
    except ValueError:
        auc = 0.
        print('ValueError. set auc = 0')
    try:
        aupr = average_precision_score(labels_epoch, preds_epoch)
    except ValueError:
        aupr = 0.
        print('ValueError. set aupr = 0')
    metrics.add('auroc', auc)
    metrics.add('aupr', aupr)

    logger.info(metrics)

    return metrics, preds_epoch, labels_epoch
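A hedged call-site sketch for run_epoch; model, val_loader, and the logger are assumed to be defined elsewhere:

criterion = torch.nn.BCEWithLogitsLoss()
val_metrics, val_preds, val_labels = run_epoch(model, val_loader, criterion,
                                               optimizer=None, epoch=0, device='cuda')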
Code example #15
File: util.py Project: parthjindal/lotteryFL_cell
        Returns a copy of the input model.
        Note: the model should have been pruned for this method to work to create buffer masks and what not.
    """
    new_model = create_model(model.__class__, device)
    source_params = dict(model.named_parameters())
    source_buffer = dict(model.named_buffers())
    for name, param in new_model.named_parameters():
        param.data.copy_(source_params[name].data)
    for name, buffer_ in new_model.named_buffers():
        buffer_.data.copy_(source_buffer[name].data)
    return new_model


metrics = MetricCollection([
    Accuracy(),
    Precision(),
    Recall(),
    F1(),
])


def train(model: nn.Module,
          train_dataloader: DataLoader,
          lr: float = 1e-3,
          device: str = 'cuda:0',
          fast_dev_run=False,
          verbose=True) -> Dict[str, torch.Tensor]:

    optimizer = torch.optim.Adam(lr=lr, params=model.parameters())
    loss_fn = nn.CrossEntropyLoss()
    num_batch = len(train_dataloader)
Code example #16
    ],
)
def test_raises_error_if_increment_not_called(method, method_input):
    tracker = MetricTracker(Accuracy(num_classes=10))
    with pytest.raises(ValueError, match=f"`{method}` cannot be called before .*"):
        if method_input is not None:
            getattr(tracker, method)(*method_input)
        else:
            getattr(tracker, method)()


@pytest.mark.parametrize(
    "base_metric, metric_input, maximize",
    [
        (Accuracy(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
        (Precision(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
        (Recall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True),
        (MeanSquaredError(), (torch.randn(50), torch.randn(50)), False),
        (MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False),
        (
            MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]),
            (torch.randint(10, (50,)), torch.randint(10, (50,))),
            True,
        ),
        (
            MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]),
            (torch.randint(10, (50,)), torch.randint(10, (50,))),
            [True, True, True],
        ),
        (MetricCollection([MeanSquaredError(), MeanAbsoluteError()]), (torch.randn(50), torch.randn(50)), False),
        (
Code example #17
def test_raises_error_if_increment_not_called(method, method_input):
    tracker = MetricTracker(Accuracy(num_classes=10))
    with pytest.raises(ValueError,
                       match=f"`{method}` cannot be called before .*"):
        if method_input is not None:
            getattr(tracker, method)(*method_input)
        else:
            getattr(tracker, method)()


@pytest.mark.parametrize(
    "base_metric, metric_input, maximize",
    [
        (Accuracy(num_classes=10),
         (torch.randint(10, (50, )), torch.randint(10, (50, ))), True),
        (Precision(num_classes=10),
         (torch.randint(10, (50, )), torch.randint(10, (50, ))), True),
        (Recall(num_classes=10),
         (torch.randint(10, (50, )), torch.randint(10, (50, ))), True),
        (MeanSquaredError(), (torch.randn(50), torch.randn(50)), False),
        (MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False),
        (
            MetricCollection([
                Accuracy(num_classes=10),
                Precision(num_classes=10),
                Recall(num_classes=10)
            ]),
            (torch.randint(10, (50, )), torch.randint(10, (50, ))),
            True,
        ),
        (